Compare commits

..

7 Commits

Author SHA1 Message Date
Michel Aractingi
14490148f3 added tdmpc2 to policy factory; shape fixes in tdmpc2 2024-11-26 11:58:29 +00:00
Michel Aractingi
16edbbdeee fixes and updated comments 2024-11-26 09:46:59 +00:00
Michel Aractingi
15090c2544 config comments 2024-11-25 09:51:33 +00:00
Michel Aractingi
166c1fc776 updated configuration parameters 2024-11-22 17:11:47 +00:00
Michel Aractingi
31984645da simplified estimate_value function in policy 2024-11-21 17:03:30 +00:00
Michel Aractingi
c41ec08ec1 remove self.model_target and added a target q ensemble only without the need to copy the
entire policy
2024-11-21 15:00:03 +00:00
Michel Aractingi
a146544765 added new implementation of tdmpc2 2024-11-20 17:30:19 +00:00
167 changed files with 6397 additions and 23228 deletions

View File

@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
pytest -sx tests/test_stuff.py::test_something
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
```
```bash
python lerobot/scripts/train.py --some.option=true

View File

@@ -7,8 +7,10 @@ on:
schedule:
- cron: "0 2 * * *"
# env:
env:
DATA_DIR: tests/data
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
jobs:
run_all_tests_cpu:
name: CPU
@@ -28,9 +30,13 @@ jobs:
working-directory: /lerobot
steps:
- name: Tests
env:
DATA_DIR: tests/data
run: pytest -v --cov=./lerobot --disable-warnings tests
- name: Tests end-to-end
env:
DATA_DIR: tests/data
run: make test-end-to-end

View File

@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry
- name: Poetry check
run: poetry check
@@ -64,7 +64,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry
- name: Install poetry-relax
run: poetry self add poetry-relax

View File

@@ -29,6 +29,7 @@ jobs:
name: Pytest
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
@@ -69,6 +70,7 @@ jobs:
name: Pytest (minimal install)
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
@@ -101,39 +103,40 @@ jobs:
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
&& rm -rf tests/outputs outputs
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
# end-to-end:
# name: End-to-end
# runs-on: ubuntu-latest
# env:
# MUJOCO_GL: egl
# steps:
# - uses: actions/checkout@v4
# with:
# lfs: true # Ensure LFS files are pulled
# - name: Install apt dependencies
# # portaudio19-dev is needed to install pyaudio
# run: |
# sudo apt-get update && \
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
end-to-end:
name: End-to-end
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled
# - name: Install poetry
# run: |
# pipx install poetry && poetry config virtualenvs.in-project true
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
- name: Install apt dependencies
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
# - name: Set up Python 3.10
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "poetry"
- name: Install poetry
run: |
pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
# - name: Install poetry dependencies
# run: |
# poetry install --all-extras
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "poetry"
# - name: Test end-to-end
# run: |
# make test-end-to-end \
# && rm -rf outputs
- name: Install poetry dependencies
run: |
poetry install --all-extras
- name: Test end-to-end
run: |
make test-end-to-end \
&& rm -rf outputs

View File

@@ -3,7 +3,7 @@ default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: debug-statements
@@ -14,12 +14,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.16.0
hooks:
- id: pyupgrade
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
rev: v0.5.2
hooks:
- id: ruff
args: [--fix]
@@ -33,6 +32,6 @@ repos:
- "--check"
- "--no-update"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.21.2
rev: v8.18.4
hooks:
- id: gitleaks

View File

@@ -267,7 +267,7 @@ We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:
```bash
python -m pytest -sv ./tests
DATA_DIR="tests/data" python -m pytest -sv ./tests
```

View File

@@ -68,7 +68,7 @@
### Acknowledgment
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
@@ -153,12 +153,10 @@ python lerobot/scripts/visualize_dataset.py \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
python lerobot/scripts/visualize_dataset.py \
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
--episode-index 0
```
@@ -210,10 +208,12 @@ dataset attributes:
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
- videos are stored in mp4 format to save space or png files
- episode_data_index saved using `safetensor` tensor serialization format
- stats saved using `safetensor` tensor serialization format
- info are saved using JSON
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
### Evaluate a pretrained policy

View File

@@ -21,7 +21,7 @@ How to decode videos?
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`

View File

@@ -32,11 +32,7 @@ import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -85,9 +81,7 @@ def get_directory_size(directory: Path) -> int:
return total_size
def load_original_frames(
imgs_dir: Path, timestamps: list[float], fps: int
) -> torch.Tensor:
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
@@ -100,11 +94,7 @@ def load_original_frames(
def save_decoded_frames(
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@@ -114,10 +104,7 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -129,17 +116,11 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [
key for key in hf_dataset.features if key.startswith("observation.image")
]
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -148,9 +129,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
break
def sample_timestamps(
timestamps_mode: str, ep_num_images: int, fps: int
) -> list[float]:
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
@@ -175,9 +154,7 @@ def decode_video_frames(
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise NotImplementedError(backend)
@@ -204,9 +181,7 @@ def benchmark_decoding(
}
with time_benchmark:
frames = decode_video_frames(
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
)
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
@@ -215,18 +190,12 @@ def benchmark_decoding(
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(
mean_squared_error(original_frames_np[i], frames_np[i])
)
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["psnr_values"].append(
peak_signal_noise_ratio(
original_frames_np[i], frames_np[i], data_range=1.0
)
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
)
result["ssim_values"].append(
structural_similarity(
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
)
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
)
if save_frames and sample == 0:
@@ -246,9 +215,7 @@ def benchmark_decoding(
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(
as_completed(futures), total=num_samples, desc="samples", leave=False
):
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
@@ -299,7 +266,7 @@ def benchmark_encoding_decoding(
)
ep_num_images = dataset.episode_data_index["to"][0].item()
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
num_pixels = width * height
video_size_bytes = video_path.stat().st_size
images_size_bytes = get_directory_size(imgs_dir)
@@ -308,13 +275,9 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
):
for backend in tqdm(
decoding_cfg["backends"], desc="decodings (backends)", leave=False
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
@@ -392,23 +355,14 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
):
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path(
"_".join(str(value) for value in encoding_cfg.values())
)
video_path = (
output_dir
/ "videos"
/ args_path
/ f"{repo_id.replace('/', '_')}.mp4"
)
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
@@ -434,9 +388,7 @@ def main(
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = (
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_df.to_csv(concatenated_path, header=True, index=False)

View File

@@ -1,18 +0,0 @@
import socket
def check_port(host, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.connect((host, port))
print(f"Connection successful to {host}:{port}!")
except Exception as e:
print(f"Connection failed to {host}:{port}: {e}")
finally:
s.close()
if __name__ == "__main__":
host = "127.0.0.1" # or "localhost"
port = 51350
check_port(host, port)

View File

@@ -1,11 +0,0 @@
FROM huggingface/lerobot-gpu:latest
RUN apt-get update && apt-get install -y --no-install-recommends \
libvulkan1 vulkan-tools \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[mani-skill]"
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

View File

@@ -1,31 +1,25 @@
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
## A. Source the parts
## Source the parts
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
## B. Install LeRobot
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
# Linux:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
# Mac M-series:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
# Mac Intel:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
@@ -42,30 +36,23 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
cd ~/lerobot && pip install -e ".[feetech]"
```
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
For Linux only (not Mac), install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## C. Configure the motors
## Configure the motors
### 1. Find the USB ports associated to each arm
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
#### a. Run the script to find ports
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
To find the port for each bus servo adapter, run the utility script:
**Find USB ports associated to your arms**
To find the correct ports for each arm, run the utility script twice:
```bash
python lerobot/scripts/find_motors_bus_port.py
```
#### b. Example outputs
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
```
Finding all available ports for the MotorBus.
@@ -77,6 +64,7 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
@@ -89,20 +77,13 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
#### c. Troubleshooting
On Linux, you might need to give access to the USB ports by running:
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
#### d. Update YAML file
Now that you have the ports, modify the *port* sections in `so100.yaml`
### 2. Configure the motors
#### a. Set IDs for all 12 motors
**Configure your motors**
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
```bash
python lerobot/scripts/configure_motor.py \
@@ -113,7 +94,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Then unplug your motor and plug the second motor and set its ID to 2.
```bash
@@ -127,25 +108,23 @@ python lerobot/scripts/configure_motor.py \
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
**Remove the gears of the 6 leader motors**
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
#### b. Remove the gears of the 6 leader motors
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
#### c. Add motor horn to all 12 motors
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
**Add motor horn to the motors**
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
## D. Assemble the arms
## Assemble the arms
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
## E. Calibrate
## Calibrate
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
#### a. Manual calibration of follower arm
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
**Manual calibration of follower arm**
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
You will need to move the follower arm to these positions sequentially:
@@ -160,8 +139,8 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_follower
```
#### b. Manual calibration of leader arm
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
**Manual calibration of leader arm**
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
@@ -174,7 +153,7 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_leader
```
## F. Teleoperate
## Teleoperate
**Simple teleop**
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
@@ -186,14 +165,14 @@ python lerobot/scripts/control_robot.py teleoperate \
```
#### a. Teleop with displaying cameras
**Teleop with displaying cameras**
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/so100.yaml
```
## G. Record a dataset
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
@@ -213,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--tags so100 tutorial \
--warmup-time-s 5 \
@@ -222,7 +202,7 @@ python lerobot/scripts/control_robot.py record \
--push-to-hub 1
```
## H. Visualize a dataset
## Visualize a dataset
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
@@ -232,25 +212,27 @@ echo ${HF_USER}/so100_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/so100_test
```
## I. Replay an episode
## Replay an episode
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py replay \
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--episode 0
```
## J. Train a policy
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/so100_test \
policy=act_so100_real \
env=so100_real \
@@ -266,16 +248,18 @@ Let's explain it:
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
## K. Evaluate your policy
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_so100_test \
--tags so100 tutorial eval \
--warmup-time-s 5 \
@@ -289,7 +273,7 @@ As you can see, it's almost the same command as previously used to record your t
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
## L. More Information
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

View File

@@ -192,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--tags moss tutorial \
--warmup-time-s 5 \
@@ -211,6 +212,7 @@ echo ${HF_USER}/moss_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/moss_test
```
@@ -218,9 +220,10 @@ python lerobot/scripts/visualize_dataset_html.py \
Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py replay \
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--episode 0
```
@@ -229,7 +232,7 @@ python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/moss_test \
policy=act_moss_real \
env=moss_real \
@@ -245,6 +248,7 @@ Let's explain it:
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
@@ -255,6 +259,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_moss_test \
--tags moss tutorial eval \
--warmup-time-s 5 \

View File

@@ -1,94 +0,0 @@
# Training a HIL-SERL Reward Classifier with LeRobot
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
---
## Training Script Overview
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).
4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.
---
## Configuring with Hydra
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
### Config File Setup
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
## Typical logs and metrics
```
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
After that, you will see training log like this one:
```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```
or evaluation log like:
```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```
### Metrics Tracking with Weights & Biases (WandB)
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
- **Training Metrics**:
- `train/accuracy`
- `train/loss`
- `train/dataloading_s`
- **Evaluation Metrics**:
- `eval/accuracy`
- `eval/loss`
- `eval/eval_s`
#### Additional Features
You can also log sample predictions during evaluation. Each logged sample will include:
- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.
#### Generate protobuf files
```bash
python -m grpc_tools.protoc \
-I lerobot/scripts/server \
--python_out=lerobot/scripts/server \
--grpc_python_out=lerobot/scripts/server \
lerobot/scripts/server/hilserl.proto
```

View File

@@ -3,128 +3,78 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
- Viewing a dataset's metadata and exploring its properties.
- Loading an existing dataset from the hub or a subset of it.
- Accessing frames by episode number.
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""
from pathlib import Path
from pprint import pprint
import imageio
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [
info.id
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
]
pprint(repo_ids)
# Let's take one for this example
repo_id = "lerobot/pusht"
# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot
# Let's take this one for this example
repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)
# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
)
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)
# You can also get a short summary by simply printing the object:
print(ds_meta)
# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# Or simply load the entire dataset:
# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
print(dataset)
print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# frame indices associated to the first episode:
# And provides additional utilities for robotics and compatibility with Pytorch
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
# Access frame indexes associated to first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
# them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c) which is a more universal format.
# Finally, we save the frames to a mp4 video for visualization.
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
# timestamp, you still get a valid timestamp.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
@@ -134,9 +84,8 @@ dataloader = torch.utils.data.DataLoader(
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
print(f"{batch['observation.state'].shape=}") # (32,8,c)
print(f"{batch['action'].shape=}") # (32,64,c)
break

View File

@@ -32,9 +32,7 @@ if torch.cuda.is_available():
print("GPU is available. Device set to:", device)
else:
device = torch.device("cpu")
print(
f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU."
)
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
policy.diffusion.num_inference_steps = 10

View File

@@ -31,24 +31,7 @@ delta_timestamps = {
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
@@ -57,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
policy.train()
policy.to(device)

View File

@@ -1,7 +1,7 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
"""
from pathlib import Path
@@ -10,17 +10,17 @@ from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
dataset_repo_id = "lerobot/aloha_static_screw_driver"
dataset_repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataset = LeRobotDataset(dataset_repo_id)
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
frame = dataset[first_idx][dataset.camera_keys[0]]
# Define the transformations
@@ -28,20 +28,15 @@ transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.ColorJitter(hue=(-0.1, 0.1)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(
dataset_repo_id, episodes=[0], image_transforms=transforms
)
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][
transformed_dataset.meta.camera_keys[0]
]
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@@ -29,7 +29,7 @@ For a visual walkthrough of the assembly process, you can refer to [this video t
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands.
Using `pip`:
```bash
@@ -778,6 +778,7 @@ Now run this to record 2 episodes:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--tags tutorial \
--warmup-time-s 5 \
@@ -786,7 +787,7 @@ python lerobot/scripts/control_robot.py record \
--num-episodes 2
```
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
@@ -839,6 +840,7 @@ In the coming months, we plan to release a foundational model for robotics. We a
You can visualize your dataset by running:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/koch_test
```
@@ -856,6 +858,7 @@ To replay the first episode of the dataset you just recorded, run the following
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--episode 0
```
@@ -868,7 +871,7 @@ Your robot should replicate movements similar to those you recorded. For example
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/koch_test \
policy=act_koch_real \
env=koch_real \
@@ -915,6 +918,7 @@ env:
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
@@ -987,6 +991,7 @@ To this end, you can use the `record` function from [`lerobot/scripts/control_ro
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_koch_test \
--tags tutorial eval \
--warmup-time-s 5 \
@@ -1005,6 +1010,7 @@ As you can see, it's almost the same command as previously used to record your t
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
```bash
python lerobot/scripts/visualize_dataset.py \
--root data \
--repo-id ${HF_USER}/eval_koch_test
```

View File

@@ -128,6 +128,7 @@ Record one episode:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--tags stretch tutorial \
--warmup-time-s 3 \
@@ -145,6 +146,7 @@ Now try to replay this episode (make sure the robot's initial position is the sa
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--episode 0
```

View File

@@ -56,7 +56,7 @@ python lerobot/scripts/control_robot.py teleoperate \
--robot-overrides max_relative_target=5
```
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/aloha.yaml \
@@ -84,6 +84,7 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--tags aloha tutorial \
--warmup-time-s 5 \
@@ -103,6 +104,7 @@ echo ${HF_USER}/aloha_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/aloha_test
```
@@ -117,6 +119,7 @@ python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--episode 0
```
@@ -125,7 +128,7 @@ python lerobot/scripts/control_robot.py replay \
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/aloha_test \
policy=act_aloha_real \
env=aloha_real \
@@ -141,6 +144,7 @@ Let's explain it:
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
4. We provided `device=cuda` since we are training on a Nvidia GPU.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
@@ -152,6 +156,7 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_aloha_test \
--tags aloha tutorial eval \
--warmup-time-s 5 \

View File

@@ -14,10 +14,7 @@ from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda")
@@ -40,44 +37,29 @@ delta_timestamps = {
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
# Load the last 10% of episodes of the dataset as a validation set.
# - Load dataset metadata
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
# - Calculate train and val episodes
total_episodes = dataset_metadata.total_episodes
episodes = list(range(dataset_metadata.total_episodes))
num_train_episodes = math.floor(total_episodes * 90 / 100)
train_episodes = episodes[:num_train_episodes]
val_episodes = episodes[num_train_episodes:]
print(f"Number of episodes in full dataset: {total_episodes}")
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
# - Load train an val datasets
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
train_dataset = LeRobotDataset(
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")

View File

@@ -1,228 +0,0 @@
import shutil
from pathlib import Path
import numpy as np
import torch
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
PUSHT_FEATURES = {
"observation.state": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"action": {
"dtype": "float32",
"shape": (2,),
"names": {
"axes": ["x", "y"],
},
},
"next.reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"next.success": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
"observation.environment_state": {
"dtype": "float32",
"shape": (16,),
"names": [
"keypoints",
],
},
"observation.image": {
"dtype": None,
"shape": (3, 96, 96),
"names": [
"channel",
"height",
"width",
],
},
}
def build_features(mode: str) -> dict:
features = PUSHT_FEATURES
if mode == "keypoints":
features.pop("observation.image")
else:
features.pop("observation.environment_state")
features["observation.image"]["dtype"] = mode
return features
def load_raw_dataset(zarr_path: Path):
try:
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
return zarr_data
def calculate_coverage(zarr_data):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
block_pos = zarr_data["state"][:, 2:4]
block_angle = zarr_data["state"][:, 4]
num_frames = len(block_pos)
coverage = np.zeros((num_frames,))
# 8 keypoints with 2 coords each
keypoints = np.zeros((num_frames, 16))
# Set x, y, theta (in radians)
goal_pos_angle = np.array([256, 256, np.pi / 4])
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage[i] = intersection_area / goal_area
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
return coverage, keypoints
def calculate_success(coverage: float, success_threshold: float):
return coverage > success_threshold
def calculate_reward(coverage: float, success_threshold: float):
return np.clip(coverage / success_threshold, 0, 1)
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
env_state = zarr_data["state"][:]
agent_pos = env_state[:, :2]
action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c)
episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"],
}
# Calculate success and reward based on the overlapping area
# of the T-object and the T-area.
coverage, keypoints = calculate_coverage(zarr_data)
success = calculate_success(coverage, success_threshold=0.95)
reward = calculate_reward(coverage, success_threshold=0.95)
features = build_features(mode)
dataset = LeRobotDataset.create(
repo_id=repo_id,
fps=10,
robot_type="2d pointer",
features=features,
image_writer_threads=4,
)
episodes = range(len(episode_data_index["from"]))
for ep_idx in episodes:
from_idx = episode_data_index["from"][ep_idx]
to_idx = episode_data_index["to"][ep_idx]
num_frames = to_idx - from_idx
for frame_idx in range(num_frames):
i = from_idx + frame_idx
frame = {
"action": torch.from_numpy(action[i]),
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
}
frame["observation.state"] = torch.from_numpy(agent_pos[i])
if mode == "keypoints":
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
else:
frame["observation.image"] = torch.from_numpy(image[i])
dataset.add_frame(frame)
dataset.save_episode(task=PUSHT_TASK)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
repo_id = "lerobot/pusht"
modes = ["video", "image", "keypoints"]
# Uncomment if you want to try with a specific mode
# modes = ["video"]
# modes = ["image"]
# modes = ["keypoints"]
raw_dir = Path("data/lerobot-raw/pusht_raw")
for mode in modes:
if mode in ["image", "keypoints"]:
repo_id += f"_{mode}"
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
main(raw_dir, repo_id=repo_id, mode=mode)
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# breakpoint()

View File

@@ -181,12 +181,8 @@ available_real_world_datasets = [
"lerobot/usc_cloth_sim",
]
available_datasets = sorted(
set(
itertools.chain(
*available_datasets_per_env.values(), available_real_world_datasets
)
)
available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# lists all available policies from `lerobot/common/policies`
@@ -228,13 +224,9 @@ available_policies_per_env = {
"dora_aloha_real": ["act_aloha_real"],
}
env_task_pairs = [
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
]
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [
(env, dataset)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)

View File

@@ -1,27 +0,0 @@
---
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
{{ card_data }}
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## Dataset Description
{{ dataset_description | default("", true) }}
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
- **License:** {{ license | default("[More Information Needed]", true)}}
## Dataset Structure
{{ dataset_structure | default("[More Information Needed]", true)}}
## Citation
**BibTeX:**
```bibtex
{{ citation_bibtex | default("[More Information Needed]", true)}}
```

View File

@@ -19,6 +19,9 @@ from math import ceil
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0):
@@ -36,29 +39,23 @@ def get_stats_einops_patterns(dataset, num_workers=0):
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue
for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.meta.camera_keys:
if isinstance(feats_type, (VideoFrame, Image)):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert (
c < h and c < w
), f"expect channel first images, but instead {batch[key].shape}"
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert (
batch[key].dtype == torch.float32
), f"expect torch.float32, but instead {batch[key].dtype=}"
assert (
batch[key].max() <= 1
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert (
batch[key].min() >= 0
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
@@ -66,7 +63,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {batch[key].shape}")
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
@@ -106,11 +103,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(
dataloader,
total=ceil(max_num_samples / batch_size),
desc="Compute mean, min, max",
)
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@@ -125,16 +118,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = (
mean[key]
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
)
max[key] = torch.maximum(
max[key], einops.reduce(batch[key], pattern, "max")
)
min[key] = torch.minimum(
min[key], einops.reduce(batch[key], pattern, "min")
)
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
@@ -143,9 +129,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
)
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@@ -159,9 +143,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = (
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
)
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
@@ -193,51 +175,39 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.meta.stats.keys())
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack(
[
ds.meta.stats[data_key][stat_key]
for ds in ls_datasets
if data_key in ds.meta.stats
],
dim=0,
),
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(
d.num_frames for d in ls_datasets if data_key in d.meta.stats
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
if data_key in d.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
if data_key in d.stats
)
)
return stats

View File

@@ -74,25 +74,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
image_transforms = None
if cfg.training.image_transforms.enable:
default_tf = OmegaConf.create(
{
"brightness": {"weight": 0.0, "min_max": None},
"contrast": {"weight": 0.0, "min_max": None},
"saturation": {"weight": 0.0, "min_max": None},
"hue": {"weight": 0.0, "min_max": None},
"sharpness": {"weight": 0.0, "min_max": None},
"max_num_transforms": None,
"random_order": False,
"image_size": None,
"interpolation": None,
"image_mean": None,
"image_std": None,
}
)
cfg_tf = OmegaConf.merge(
OmegaConf.create(default_tf), cfg.training.image_transforms
)
cfg_tf = cfg.training.image_transforms
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
@@ -106,18 +88,12 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width)
if cfg_tf.image_size
else None,
interpolation=cfg_tf.interpolation,
image_mean=cfg_tf.image_mean,
image_std=cfg_tf.image_std,
)
if isinstance(cfg.dataset_repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
@@ -125,6 +101,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
@@ -135,8 +112,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset

View File

@@ -1,166 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import queue
import threading
from pathlib import Path
import numpy as np
import PIL.Image
import torch
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
image_writer.stop()
raise e
return wrapper
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
if image_array.dtype != np.uint8:
# Assume the image is in [0, 1] range for floating-point data
image_array = np.clip(image_array, 0, 1)
image_array = (image_array * 255).astype(np.uint8)
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
def worker_thread_loop(queue: queue.Queue):
while True:
item = queue.get()
if item is None:
queue.task_done()
break
image_array, fpath = item
write_image(image_array, fpath)
queue.task_done()
def worker_process(queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_loop, args=(queue,))
t.daemon = True
t.start()
threads.append(t)
for t in threads:
t.join()
class AsyncImageWriter:
"""
This class abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it creates a threads pool of size `num_threads`.
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
def __init__(self, num_processes: int = 0, num_threads: int = 1):
self.num_processes = num_processes
self.num_threads = num_threads
self.queue = None
self.threads = []
self.processes = []
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError(
"Number of threads and processes must be greater than zero."
)
if self.num_processes == 0:
# Use threading
self.queue = queue.Queue()
for _ in range(self.num_threads):
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
t.daemon = True
t.start()
self.threads.append(t)
else:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(
target=worker_process, args=(self.queue, self.num_threads)
)
p.daemon = True
p.start()
self.processes.append(p)
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath))
def wait_until_done(self):
self.queue.join()
def stop(self):
if self._stopped:
return
if self.num_processes == 0:
for _ in self.threads:
self.queue.put(None)
for t in self.threads:
t.join()
else:
num_nones = self.num_processes * self.num_threads
for _ in range(num_nones):
self.queue.put(None)
for p in self.processes:
p.join()
if p.is_alive():
p.terminate()
self.queue.close()
self.queue.join_thread()
self._stopped = True

File diff suppressed because it is too large Load Diff

View File

@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
else:
self._delta_timestamps = None
def _make_data_spec(
self, data_spec: dict[str, Any], buffer_capacity: int
) -> dict[str, dict[str, Any]]:
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
@@ -156,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {
"dtype": np.dtype("?"),
"shape": (buffer_capacity,),
},
OnlineBuffer.INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
}
for k, v in data_spec.items():
complete_data_spec[k] = {
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
@@ -207,10 +187,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
next_index - 1
]
if self.num_samples > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -245,19 +223,15 @@ class OnlineBuffer(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(
np.unique(
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
]
)
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
)
@property
def num_frames(self) -> int:
def num_samples(self) -> int:
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
def __len__(self):
return self.num_frames
return self.num_samples
def _item_to_tensors(self, item: dict) -> dict:
item_ = {}
@@ -287,9 +261,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
episode_data_indices
]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -306,8 +278,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0])
| (episode_timestamps[-1] < query_ts[is_pad])
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
@@ -322,9 +293,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
)
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
def compute_sampler_weights(
@@ -355,19 +324,13 @@ def compute_sampler_weights(
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (
online_dataset is None or len(online_dataset) == 0
):
raise ValueError(
"At least one of `offline_dataset` or `online_dataset` should be contain data."
)
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = (
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
weights = []

View File

@@ -0,0 +1,468 @@
"""Functions to create an empty dataset, and populate it with frames."""
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
import concurrent
import json
import logging
import multiprocessing
import shutil
from pathlib import Path
import torch
import tqdm
from PIL import Image
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.common.utils.utils import log_say
from lerobot.scripts.push_dataset_to_hub import (
push_dataset_card_to_hub,
push_meta_data_to_hub,
push_videos_to_hub,
save_meta_data,
)
########################################################################################
# Asynchrounous saving of images on disk
########################################################################################
def safe_stop_image_writer(func):
# TODO(aliberts): Allow to pass custom exceptions
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
image_writer = kwargs.get("dataset", {}).get("image_writer")
if image_writer is not None:
print("Waiting for image writer to terminate...")
stop_image_writer(image_writer, timeout=20)
raise e
return wrapper
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
img = Image.fromarray(img_tensor.numpy())
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
path.parent.mkdir(parents=True, exist_ok=True)
img.save(str(path), quality=100)
def loop_to_save_images_in_threads(image_queue, num_threads):
if num_threads < 1:
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = image_queue.get()
# As usually done, exit loop when receiving None to stop the worker
if frame_data is None:
break
image, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures)
progress_bar.update(len(futures))
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
if num_processes < 1:
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
if num_threads_per_process < 1:
raise NotImplementedError(
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
)
processes = []
for _ in range(num_processes):
process = multiprocessing.Process(
target=loop_to_save_images_in_threads,
args=(image_queue, num_threads_per_process),
)
process.start()
processes.append(process)
return processes
def stop_processes(processes, queue, timeout):
# Send None to each process to signal them to stop
for _ in processes:
queue.put(None)
# Wait maximum 20 seconds for all processes to terminate
for process in processes:
process.join(timeout=timeout)
# If not terminated after 20 seconds, force termination
if process.is_alive():
process.terminate()
# Close the queue, no more items can be put in the queue
queue.close()
# Ensure all background queue threads have finished
queue.join_thread()
def start_image_writer(num_processes, num_threads):
"""This function abstract away the initialisation of processes or/and threads to
save images on disk asynchrounously, which is critical to control a robot and record data
at a high frame rate.
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
where each subprocess starts their own threads pool of size `num_threads`.
The optimal number of processes and threads depends on your computer capabilities.
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
"""
image_writer = {}
if num_processes == 0:
futures = []
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
else:
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
image_queue = multiprocessing.Queue()
processes_pool = start_image_writer_processes(
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
)
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
return image_writer
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
called image writer which contains either a pool of processes or a pool of threads.
"""
if "threads_pool" in image_writer:
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
else:
image_queue = image_writer["image_queue"]
image_queue.put((image, key, frame_index, episode_index, videos_dir))
def stop_image_writer(image_writer, timeout):
if "threads_pool" in image_writer:
futures = image_writer["futures"]
# Before exiting function, wait for all threads to complete
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
concurrent.futures.wait(futures, timeout=timeout)
progress_bar.update(len(futures))
else:
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
stop_processes(processes_pool, image_queue, timeout=timeout)
########################################################################################
# Functions to initialize, resume and populate a dataset
########################################################################################
def init_dataset(
repo_id,
root,
force_override,
fps,
video,
write_images,
num_image_writer_processes,
num_image_writer_threads,
):
local_dir = Path(root) / repo_id
if local_dir.exists() and force_override:
shutil.rmtree(local_dir)
episodes_dir = local_dir / "episodes"
episodes_dir.mkdir(parents=True, exist_ok=True)
videos_dir = local_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
# Logic to resume data recording
rec_info_path = episodes_dir / "data_recording_info.json"
if rec_info_path.exists():
with open(rec_info_path) as f:
rec_info = json.load(f)
num_episodes = rec_info["last_episode_index"] + 1
else:
num_episodes = 0
dataset = {
"repo_id": repo_id,
"local_dir": local_dir,
"videos_dir": videos_dir,
"episodes_dir": episodes_dir,
"fps": fps,
"video": video,
"rec_info_path": rec_info_path,
"num_episodes": num_episodes,
}
if write_images:
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
# which is critical to control a robot and record data at a high frame rate.
image_writer = start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads,
)
dataset["image_writer"] = image_writer
return dataset
def add_frame(dataset, observation, action):
if "current_episode" not in dataset:
# initialize episode dictionary
ep_dict = {}
for key in observation:
if key not in ep_dict:
ep_dict[key] = []
for key in action:
if key not in ep_dict:
ep_dict[key] = []
ep_dict["episode_index"] = []
ep_dict["frame_index"] = []
ep_dict["timestamp"] = []
ep_dict["next.done"] = []
dataset["current_episode"] = ep_dict
dataset["current_frame_index"] = 0
ep_dict = dataset["current_episode"]
episode_index = dataset["num_episodes"]
frame_index = dataset["current_frame_index"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
ep_dict["episode_index"].append(episode_index)
ep_dict["frame_index"].append(frame_index)
ep_dict["timestamp"].append(frame_index / fps)
ep_dict["next.done"].append(False)
img_keys = [key for key in observation if "image" in key]
non_img_keys = [key for key in observation if "image" not in key]
# Save all observed modalities except images
for key in non_img_keys:
ep_dict[key].append(observation[key])
# Save actions
for key in action:
ep_dict[key].append(action[key])
if "image_writer" not in dataset:
dataset["current_frame_index"] += 1
return
# Save images
image_writer = dataset["image_writer"]
for key in img_keys:
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
async_save_image(
image_writer,
image=observation[key],
key=key,
frame_index=frame_index,
episode_index=episode_index,
videos_dir=str(videos_dir),
)
if video:
fname = f"{key}_episode_{episode_index:06d}.mp4"
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
else:
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
ep_dict[key].append(frame_info)
dataset["current_frame_index"] += 1
def delete_current_episode(dataset):
del dataset["current_episode"]
del dataset["current_frame_index"]
# delete temporary images
episode_index = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
shutil.rmtree(tmp_imgs_dir)
def save_current_episode(dataset):
episode_index = dataset["num_episodes"]
ep_dict = dataset["current_episode"]
episodes_dir = dataset["episodes_dir"]
rec_info_path = dataset["rec_info_path"]
ep_dict["next.done"][-1] = True
for key in ep_dict:
if "observation" in key and "image" not in key:
ep_dict[key] = torch.stack(ep_dict[key])
ep_dict["action"] = torch.stack(ep_dict["action"])
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
ep_path = episodes_dir / f"episode_{episode_index}.pth"
torch.save(ep_dict, ep_path)
rec_info = {
"last_episode_index": episode_index,
}
with open(rec_info_path, "w") as f:
json.dump(rec_info, f)
# force re-initialization of episode dictionnary during add_frame
del dataset["current_episode"]
dataset["num_episodes"] += 1
def encode_videos(dataset, image_keys, play_sounds):
log_say("Encoding videos", play_sounds)
num_episodes = dataset["num_episodes"]
videos_dir = dataset["videos_dir"]
local_dir = dataset["local_dir"]
fps = dataset["fps"]
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in tqdm.tqdm(range(num_episodes)):
for key in image_keys:
# key = f"observation.images.{name}"
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
fname = f"{key}_episode_{episode_index:06d}.mp4"
video_path = local_dir / "videos" / fname
if video_path.exists():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
log_say("Consolidate episodes", play_sounds)
num_episodes = dataset["num_episodes"]
episodes_dir = dataset["episodes_dir"]
videos_dir = dataset["videos_dir"]
video = dataset["video"]
fps = dataset["fps"]
repo_id = dataset["repo_id"]
ep_dicts = []
for episode_index in tqdm.tqdm(range(num_episodes)):
ep_path = episodes_dir / f"episode_{episode_index}.pth"
ep_dict = torch.load(ep_path)
ep_dicts.append(ep_dict)
data_dict = concatenate_episodes(ep_dicts)
if video:
image_keys = [key for key in data_dict if "image" in key]
encode_videos(dataset, image_keys, play_sounds)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {
"codebase_version": CODEBASE_VERSION,
"fps": fps,
"video": video,
}
if video:
info["encoding"] = get_default_encoding()
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
return lerobot_dataset
def save_lerobot_dataset_on_disk(lerobot_dataset):
hf_dataset = lerobot_dataset.hf_dataset
info = lerobot_dataset.info
stats = lerobot_dataset.stats
episode_data_index = lerobot_dataset.episode_data_index
local_dir = lerobot_dataset.videos_dir.parent
meta_data_dir = local_dir / "meta_data"
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
hf_dataset = lerobot_dataset.hf_dataset
local_dir = lerobot_dataset.videos_dir.parent
videos_dir = lerobot_dataset.videos_dir
repo_id = lerobot_dataset.repo_id
video = lerobot_dataset.video
meta_data_dir = local_dir / "meta_data"
if not (local_dir / "train").exists():
raise ValueError(
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
)
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
if "image_writer" in dataset:
logging.info("Waiting for image writer to terminate...")
image_writer = dataset["image_writer"]
stop_image_writer(image_writer, timeout=20)
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
if run_compute_stats:
log_say("Computing dataset statistics", play_sounds)
lerobot_dataset.stats = compute_stats(lerobot_dataset)
else:
logging.info("Skipping computation of the dataset statistics")
lerobot_dataset.stats = {}
save_lerobot_dataset_on_disk(lerobot_dataset)
if push_to_hub:
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
return lerobot_dataset

View File

@@ -37,16 +37,10 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
assert c > 0
def rechunk_recompress_array(
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
):
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
old_arr = group[name]
if chunks is None:
chunks = (
(chunk_length,) + old_arr.chunks[1:]
if chunk_length is not None
else old_arr.chunks
)
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
check_chunks_compatible(chunks, old_arr.shape)
if compressor is None:
@@ -88,18 +82,13 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
for i in range(len(shape) - 1):
this_chunk_bytes = itemsize * np.prod(rshape[:i])
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
if (
this_chunk_bytes <= target_chunk_bytes
and next_chunk_bytes > target_chunk_bytes
):
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
split_idx = i
rchunks = rshape[:split_idx]
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
this_max_chunk_length = rshape[split_idx]
next_chunk_length = min(
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
)
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
rchunks.append(next_chunk_length)
len_diff = len(shape) - len(rchunks)
rchunks.extend([1] * len_diff)
@@ -135,13 +124,7 @@ class ReplayBuffer:
root.require_group("data", overwrite=False)
meta = root.require_group("meta", overwrite=False)
if "episode_ends" not in meta:
meta.zeros(
"episode_ends",
shape=(0,),
dtype=np.int64,
compressor=None,
overwrite=False,
)
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
return cls(root=root)
@classmethod
@@ -210,11 +193,7 @@ class ReplayBuffer:
root = zarr.group(store=store)
# copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=src_store,
dest=store,
source_path="/meta",
dest_path="/meta",
if_exists=if_exists,
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
)
data_group = root.create_group("data", overwrite=True)
if keys is None:
@@ -222,9 +201,7 @@ class ReplayBuffer:
for key in keys:
value = src_root["data"][key]
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = cls._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
@@ -309,17 +286,13 @@ class ReplayBuffer:
meta_group = root.create_group("meta", overwrite=True)
# save meta, no chunking
for key, value in self.root["meta"].items():
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape
)
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
# save data, chunk
data_group = root.create_group("data", overwrite=True)
for key, value in self.root["data"].items():
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
if isinstance(value, zarr.Array):
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
@@ -366,19 +339,13 @@ class ReplayBuffer:
@staticmethod
def resolve_compressor(compressor="default"):
if compressor == "default":
compressor = numcodecs.Blosc(
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
)
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
elif compressor == "disk":
compressor = numcodecs.Blosc(
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
)
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
return compressor
@classmethod
def _resolve_array_compressor(
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
):
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
# allows compressor to be explicitly set to None
cpr = "nil"
if isinstance(compressors, dict):
@@ -437,11 +404,7 @@ class ReplayBuffer:
if self.backend == "zarr":
for key, value in np_data.items():
_ = meta_group.array(
name=key,
data=value,
shape=value.shape,
chunks=value.shape,
overwrite=True,
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
)
else:
meta_group.update(np_data)
@@ -551,18 +514,10 @@ class ReplayBuffer:
# create array
if key not in self.data:
if is_zarr:
cks = self._resolve_array_chunks(
chunks=chunks, key=key, array=value
)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
arr = self.data.zeros(
name=key,
shape=new_shape,
chunks=cks,
dtype=value.dtype,
compressor=cpr,
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
)
else:
# copy data to prevent modify
@@ -589,9 +544,7 @@ class ReplayBuffer:
# rechunk
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
rechunk_recompress_array(
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
)
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
def drop_episode(self):
is_zarr = self.backend == "zarr"

View File

@@ -38,9 +38,7 @@ import argparse
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
AVAILABLE_RAW_REPO_IDS,
)
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
@@ -75,9 +73,7 @@ def encode_datasets(
check_repo_id(raw_repo_id)
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
dataset_raw_dir = raw_dir / raw_repo_id
dataset_dir = (
local_dir / dataset_repo_id_push if local_dir is not None else None
)
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
encoding = {
"vcodec": vcodec,
"pix_fmt": pix_fmt,

View File

@@ -133,9 +133,7 @@ class Jpeg2k(Codec):
)
def decode(self, buf, out=None):
return imagecodecs.jpeg2k_decode(
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
)
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
class JpegXl(Codec):

View File

@@ -30,12 +30,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -44,9 +44,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def get_cameras(hdf5_data):
# ignore depth channel, not currently handled
# TODO(rcadene): add depth
rgb_cameras = [
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
] # noqa: SIM118
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
return rgb_cameras
@@ -75,9 +73,7 @@ def check_format(raw_dir) -> bool:
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert (
c < h and c < w
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(
@@ -138,17 +134,14 @@ def load_from_raw(
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -188,18 +181,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -24,11 +24,8 @@ from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
)
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
from lerobot.common.datasets.video_utils import VideoFrame

View File

@@ -26,10 +26,8 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
@@ -44,19 +42,11 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(
raw_dir: Path,
videos_dir: Path,
fps: int,
video: bool,
episodes: list[int] | None = None,
):
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
raise ValueError(
f"Missing reference files for camera, starting with in '{raw_dir}'"
)
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
# select first camera in alphanumeric order
reference_key = sorted(reference_files)[0].stem
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
@@ -117,9 +107,7 @@ def load_from_raw(
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
# each episode starts with timestamp 0 to match the ones from the video
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
lambda x: x - x.iloc[0]
)
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
del df["timestamp_utc"]
@@ -132,9 +120,7 @@ def load_from_raw(
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
raise ValueError(
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
)
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
# Create symlink to raw videos directory (that needs to be absolute not relative)
videos_dir.parent.mkdir(parents=True, exist_ok=True)
@@ -166,9 +152,7 @@ def load_from_raw(
data_dict[key] = torch.from_numpy(df[key].values)
# is vector
elif df[key].iloc[0].shape[0] > 1:
data_dict[key] = torch.stack(
[torch.from_numpy(x.copy()) for x in df[key].values]
)
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
else:
raise ValueError(key)
@@ -186,18 +170,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -0,0 +1,639 @@
OPENX_DATASET_CONFIGS:
fractal20220817_data:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- base_pose_tool_reached
- gripper_closed
fps: 3
kuka:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- clip_function_input/base_pose_tool_reached
- gripper_closed
fps: 10
bridge_openx:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- EEF_state
- gripper_state
fps: 5
taco_play:
image_obs_keys:
- rgb_static
- rgb_gripper
depth_obs_keys:
- depth_static
- depth_gripper
state_obs_keys:
- state_eef
- state_gripper
fps: 15
jaco_play:
image_obs_keys:
- image
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state_eef
- state_gripper
fps: 10
berkeley_cable_routing:
image_obs_keys:
- image
- top_image
- wrist45_image
- wrist225_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
roboturk:
image_obs_keys:
- front_rgb
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
nyu_door_opening_surprising_effectiveness:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
viola:
image_obs_keys:
- agentview_rgb
- eye_in_hand_rgb
depth_obs_keys:
- null
state_obs_keys:
- joint_states
- gripper_states
fps: 20
berkeley_autolab_ur5:
image_obs_keys:
- image
- hand_image
depth_obs_keys:
- image_with_depth
state_obs_keys:
- state
fps: 5
toto:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30
language_table:
image_obs_keys:
- rgb
depth_obs_keys:
- null
state_obs_keys:
- effector_translation
fps: 10
columbia_cairlab_pusht_real:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- robot_state
fps: 10
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- depth_image
state_obs_keys:
- ee_position
- ee_orientation
fps: 20
nyu_rot_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
io_ai_tech:
image_obs_keys:
- image
- image_fisheye
- image_left_side
- image_right_side
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 3
stanford_hydra_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
austin_buds_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
nyu_franka_play_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- image_additional_view
depth_obs_keys:
- depth
- depth_additional_view
state_obs_keys:
- eef_state
fps: 3
maniskill_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- depth
- wrist_depth
state_obs_keys:
- tcp_pose
- gripper_state
fps: 20
furniture_bench_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_franka_exploration_dataset_converted_externally_to_rlds:
image_obs_keys:
- highres_image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
ucsd_kitchen_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
fps: 2
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 3
spoc:
image_obs_keys:
- image
- image_manipulation
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 3
austin_sailor_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
austin_sirius_dataset_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
bc_z:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- present/xyz
- present/axis_angle
- present/sensed_close
fps: 10
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
image_obs_keys:
- image
- image2
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- end_effector_pose
fps: 10
utokyo_xarm_bimanual_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- pose_r
fps: 10
robo_net:
image_obs_keys:
- image
- image1
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 1
robo_set:
image_obs_keys:
- image_left
- image_right
- image_wrist
depth_obs_keys:
- null
state_obs_keys:
- state
- state_velocity
fps: 5
berkeley_mvp_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- gripper
- pose
- joint_pos
fps: 5
berkeley_rpt_converted_externally_to_rlds:
image_obs_keys:
- hand_image
depth_obs_keys:
- null
state_obs_keys:
- joint_pos
- gripper
fps: 30
kaist_nonprehensile_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
stanford_mask_vit_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
tokyo_u_lsmo_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
dlr_sara_pour_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_sara_grid_clamp_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
dlr_edan_shared_control_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
asu_table_top_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 12.5
stanford_robocook_converted_externally_to_rlds:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- eef_state
- gripper_state
fps: 5
imperialcollege_sawyer_wrist_cam:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 20
uiuc_d3field:
image_obs_keys:
- image_1
- image_2
depth_obs_keys:
- depth_1
- depth_2
state_obs_keys:
- null
fps: 1
utaustin_mutex:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
berkeley_fanuc_manipulation:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- joint_state
- gripper_state
fps: 10
cmu_playing_with_food:
image_obs_keys:
- image
- finger_vision_1
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 10
cmu_play_fusion:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 5
cmu_stretch:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- eef_state
- gripper_state
fps: 10
berkeley_gnm_recon:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 3
berkeley_gnm_cory_hall:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 5
berkeley_gnm_sac_son:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- state
- position
- yaw
fps: 10
droid:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
droid_100:
image_obs_keys:
- exterior_image_1_left
- exterior_image_2_left
- wrist_image_left
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 15
fmb:
image_obs_keys:
- image_side_1
- image_side_2
- image_wrist_1
- image_wrist_2
depth_obs_keys:
- image_side_1_depth
- image_side_2_depth
- image_wrist_1_depth
- image_wrist_2_depth
state_obs_keys:
- proprio
fps: 10
dobbe:
image_obs_keys:
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- proprio
fps: 3.75
usc_cloth_sim_converted_externally_to_rlds:
image_obs_keys:
- image
depth_obs_keys:
- null
state_obs_keys:
- null
fps: 10
plex_robosuite:
image_obs_keys:
- image
- wrist_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 20
conq_hose_manipulation:
image_obs_keys:
- frontleft_fisheye_image
- frontright_fisheye_image
- hand_color_image
depth_obs_keys:
- null
state_obs_keys:
- state
fps: 30

View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the Licens e.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
data_utils.py
Additional utils for data processing.
"""
from typing import Any, Dict, List
import tensorflow as tf
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts gripper actions from continuous to binary values (0 and 1).
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
values based on the state that is reached _after_ those intermediate values.
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
chunk of intermediate values as the last action in the trajectory.
The `scan_fn` implements the following logic:
new_actions = np.empty_like(actions)
carry = actions[-1]
for i in reversed(range(actions.shape[0])):
if in_between_mask[i]:
carry = carry
else:
carry = float(open_mask[i])
new_actions[i] = carry
"""
open_mask, closed_mask = actions > 0.95, actions < 0.05
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
is_open_float = tf.cast(open_mask, tf.float32)
def scan_fn(carry, i):
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
return 1 - actions
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
"""
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
"""
# Note =>> -1 for closing, 1 for opening, 0 for no change
opening_mask, closing_mask = actions < -0.1, actions > 0.1
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
def scan_fn(carry, i):
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
# If no relative grasp, assumes open for whole trajectory
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
start = tf.cond(start == 0, lambda: 1, lambda: start)
# Note =>> -1 for closed, 1 for open
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
return new_actions
# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
return traj_truncated
# === RLDS Dataset Initialization Utilities ===
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
print("\n######################################################################################")
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
pad = 80 - len(dataset_kwargs["name"])
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
print("######################################################################################\n")

View File

@@ -0,0 +1,200 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Episode transforms for DROID dataset.
"""
from typing import Any, Dict
import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg
def rmat_to_euler(rot_mat):
return tfg.euler.from_rotation_matrix(rot_mat)
def euler_to_rmat(euler):
return tfg.rotation_matrix_3d.from_euler(euler)
def invert_rmat(rot_mat):
return tfg.rotation_matrix_3d.inverse(rot_mat)
def rotmat_to_rot6d(mat):
"""
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
Args:
mat: rotation matrix
Returns: 6d vector (first two rows of rotation matrix)
"""
r6 = mat[..., :2, :]
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
return r6_flat
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
"""
Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
Args:
velocity: 6d velocity action (3 x translation, 3 x rotation)
wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
"""
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
r_frame_inv = invert_rmat(r_frame)
# world to wrist: dT_pi = R^-1 dT_rbt
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
# world to wrist: dR_pi = R^-1 dR_rbt R
dr_ = euler_to_rmat(velocity[:, 3:6])
dr_ = r_frame_inv @ (dr_ @ r_frame)
dr_r6 = rotmat_to_rot6d(dr_)
return tf.concat([vel_t, dr_r6], axis=-1)
def rand_swap_exterior_images(img1, img2):
"""
Randomly swaps the two exterior images (for training with single exterior input).
"""
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *wrist* frame of the robot.
"""
wrist_act = velocity_act_to_wrist_frame(
trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
)
trajectory["action"] = tf.concat(
(
wrist_act,
trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
rand_swap_exterior_images(
trajectory["observation"]["exterior_image_1_left"],
trajectory["observation"]["exterior_image_2_left"],
)
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
DROID dataset transformation for actions expressed in *base* frame of the robot.
"""
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
trajectory["action"] = tf.concat(
(
dt,
dr_,
1 - trajectory["action_dict"]["gripper_position"],
),
axis=-1,
)
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["cartesian_position"],
trajectory["observation"]["gripper_position"],
),
axis=-1,
)
return trajectory
def zero_action_filter(traj: Dict) -> bool:
"""
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
"""
droid_q01 = tf.convert_to_tensor(
[
-0.7776297926902771,
-0.5803514122962952,
-0.5795090794563293,
-0.6464047729969025,
-0.7041108310222626,
-0.8895104378461838,
]
)
droid_q99 = tf.convert_to_tensor(
[
0.7597932070493698,
0.5726242214441299,
0.7351000607013702,
0.6705610305070877,
0.6464948207139969,
0.8897542208433151,
]
)
droid_norm_0_act = (
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
)
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)

View File

@@ -0,0 +1,859 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
OpenVLA: https://github.com/openvla/openvla
Octo: https://github.com/octo-models/octo
transforms.py
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
Transforms adopt the following structure:
Input: Dictionary of *batched* features (i.e., has leading time dimension)
Output: Dictionary `step` =>> {
"observation": {
<image_keys, depth_image_keys>
State (in chosen state representation)
},
"action": Action (in chosen action representation),
"language_instruction": str
}
"""
from typing import Any, Dict
import tensorflow as tf
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
binarize_gripper_actions,
invert_gripper_actions,
rel2abs_gripper_actions,
relabel_bridge_actions,
)
def droid_baseact_transform_fn():
from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform
return droid_baseact_transform
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to version of Bridge V2 in Open X-Embodiment mixture.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key in ["observation", "action"]:
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to original version of Bridge V2 from the official project website.
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
"""
for key in trajectory:
if key == "traj_metadata":
continue
elif key == "observation":
for key2 in trajectory[key]:
trajectory[key][key2] = trajectory[key][key2][1:]
else:
trajectory[key] = trajectory[key][1:]
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory = relabel_bridge_actions(trajectory)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
[
trajectory["action"][:, :6],
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
],
axis=1,
)
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
return trajectory
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
# decode compressed state
eef_value = tf.io.decode_compressed(
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
compression_type="ZLIB",
)
eef_value = tf.io.decode_raw(eef_value, tf.float32)
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
gripper_value = tf.io.decode_compressed(
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
)
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
trajectory["action"] = trajectory["action"]["rel_actions_world"]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
:, -1:
]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
tf.zeros_like(trajectory["action"]["world_vector"]),
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert absolute gripper action, +1 = open, 0 = close
gripper_action = invert_gripper_actions(
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
return trajectory
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# make gripper action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
gripper_action = invert_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action,
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
# make gripper action absolute action, +1 = open, 0 = close
gripper_action = trajectory["action"]["gripper_closedness_action"]
gripper_action = rel2abs_gripper_actions(gripper_action)
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
gripper_action[:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# default to "open" gripper
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.ones_like(trajectory["action"][:, :1]),
),
axis=-1,
)
# decode language instruction
instruction_bytes = trajectory["observation"]["instruction"]
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
:, 0
]
return trajectory
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["world_vector"],
trajectory["action"]["rotation_delta"],
trajectory["action"]["gripper_closedness_action"][:, None],
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
trajectory["action"] = trajectory["action"][..., :7]
return trajectory
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
trajectory["observation"]["state"][:, 7:10],
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
return trajectory
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
return trajectory
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
trajectory["observation"]["depth_additional_view"] = tf.cast(
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
)
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
# clip gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, -8:-2],
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
),
axis=-1,
)
return trajectory
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
return trajectory
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["state"][:, :7],
trajectory["observation"]["state"][:, -1:],
),
axis=-1,
)
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tf.zeros_like(trajectory["action"][:, :3]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"]["future/xyz_residual"][:, :3],
trajectory["action"]["future/axis_angle_residual"][:, :3],
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
),
axis=-1,
)
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
return trajectory
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., -7:]
return trajectory
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :4],
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
trajectory["observation"]["state"] = tf.concat((
tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
trajectory["observation"]["pose"],
trajectory["observation"]["joint_pos"],),
axis=-1,)
"""
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
return trajectory
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["end_effector_pose"][:, :4],
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :4],
tf.zeros_like(trajectory["action"][:, :2]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
return trajectory
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# invert gripper action, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(trajectory["action"][:, -1:]),
),
axis=-1,
)
return trajectory
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
return trajectory
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, 7:8],
),
axis=-1,
)
return trajectory
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
# invert gripper action + clip, +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :6],
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
),
axis=-1,
)
return trajectory
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
trajectory["action"] = tf.concat(
(
trajectory["action"],
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
),
axis=-1,
)
return trajectory
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
import tensorflow_graphics.geometry.transformation as tft
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
trajectory["action"][:, -1:],
),
axis=-1,
)
return trajectory
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :3],
trajectory["action"][:, -4:],
),
axis=-1,
)
return trajectory
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["eef_state"] = tf.concat(
(
trajectory["observation"]["state"][:, :3],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
),
axis=-1,
)
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
trajectory["action"] = trajectory["action"][..., :-1]
return trajectory
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
trajectory["observation"]["state"] = tf.concat(
(
trajectory["observation"]["position"],
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
trajectory["observation"]["yaw"],
),
axis=-1,
)
trajectory["action"] = tf.concat(
(
trajectory["action"],
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"]),
tf.zeros_like(trajectory["action"][:, :1]),
),
axis=-1,
)
return trajectory
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = tf.concat(
(
trajectory["observation"]["eef_pose"],
trajectory["observation"]["state_gripper_pose"][..., None],
),
axis=-1,
)
return trajectory
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# every input feature is batched, ie has leading batch dimension
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
return trajectory
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
# gripper action is in -1...1 --> clip to 0...1, flip
gripper_action = trajectory["action"][:, -1:]
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
trajectory["action"] = tf.concat(
(
trajectory["action"][:, :7],
gripper_action,
),
axis=-1,
)
return trajectory
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory
# === Registry ===
OPENX_STANDARDIZATION_TRANSFORMS = {
"bridge_openx": bridge_openx_dataset_transform,
"bridge_orig": bridge_orig_dataset_transform,
"bridge_dataset": bridge_orig_dataset_transform,
"ppgm": ppgm_dataset_transform,
"ppgm_static": ppgm_dataset_transform,
"ppgm_wrist": ppgm_dataset_transform,
"fractal20220817_data": rt1_dataset_transform,
"kuka": kuka_dataset_transform,
"taco_play": taco_play_dataset_transform,
"jaco_play": jaco_play_dataset_transform,
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
"roboturk": roboturk_dataset_transform,
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
"viola": viola_dataset_transform,
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
"toto": toto_dataset_transform,
"language_table": language_table_dataset_transform,
"columbia_cairlab_pusht_real": pusht_dataset_transform,
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
"bc_z": bc_z_dataset_transform,
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
"robo_net": robo_net_dataset_transform,
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
"dlr_sara_pour_converted_externally_to_rlds": identity_transform,
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
"uiuc_d3field": uiuc_d3field_dataset_transform,
"utaustin_mutex": utaustin_mutex_dataset_transform,
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
"cmu_play_fusion": playfusion_dataset_transform,
"cmu_stretch": cmu_stretch_dataset_transform,
"berkeley_gnm_recon": gnm_dataset_transform,
"berkeley_gnm_cory_hall": gnm_dataset_transform,
"berkeley_gnm_sac_son": gnm_dataset_transform,
"droid": droid_baseact_transform_fn(),
"droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid
"fmb": fmb_transform,
"dobbe": dobbe_dataset_transform,
"robo_set": robo_set_dataset_transform,
"usc_cloth_sim_converted_externally_to_rlds": identity_transform,
"plex_robosuite": identity_transform,
"conq_hose_manipulation": identity_transform,
"io_ai_tech": identity_transform,
"spoc": identity_transform,
}

View File

@@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
Example:
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
--repo-id your_hub/sampled_bridge_data_v2 \
--raw-format rlds \
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
--repo-id youliangtan/sampled_bridge_data_v2 \
--raw-format openx_rlds.bridge_orig \
--episodes 3 4 5 8 9
Exact dataset fps defined in openx/config.py, obtained from:
@@ -38,21 +35,28 @@ import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import yaml
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
_openx_list = yaml.safe_load(f)
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
np.set_printoptions(precision=2)
@@ -104,6 +108,7 @@ def load_from_raw(
video: bool,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
"""
Args:
@@ -131,23 +136,18 @@ def load_from_raw(
# we will apply the standardization transform if the dataset_name is provided
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
# search for 'image' keys in the observations
image_keys = []
state_keys = []
observation_info = dataset_info.features["steps"]["observation"]
for key in observation_info:
# check whether the key is for an image or a vector observation
if len(observation_info[key].shape) == 3:
# only adding uint8 images discards depth images
if observation_info[key].dtype == tf.uint8:
image_keys.append(key)
else:
state_keys.append(key)
if openx_dataset_name is not None:
print(" - applying standardization transform for dataset: ", openx_dataset_name)
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
dataset = dataset.map(transform_fn)
lang_key = (
"language_instruction"
if "language_instruction" in dataset.element_spec
else None
)
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
else:
obs_keys = dataset_info.features["steps"]["observation"].keys()
image_keys = [key for key in obs_keys if "image" in key]
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
print(" - image_keys: ", image_keys)
print(" - lang_key: ", lang_key)
@@ -193,33 +193,50 @@ def load_from_raw(
num_frames = episode["action"].shape[0]
ep_dict = {}
for key in state_keys:
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
###########################################################
# Handle the episodic data
ep_dict["action"] = tf_to_torch(episode["action"])
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
ep_dict["discount"] = tf_to_torch(episode["discount"])
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
ep_dict = {}
langs = [] # TODO: might be located in "observation"
image_array_dict = {key: [] for key in image_keys}
# We will create the state observation tensor by stacking the state
# obs keys defined in the openx/configs.py
if openx_dataset_name is not None:
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
# stack the state observations, if is None, pad with zeros
states = []
for key in state_obs_keys:
if key in episode["observation"]:
states.append(tf_to_torch(episode["observation"][key]))
else:
states.append(torch.zeros(num_frames, 1)) # pad with zeros
states = torch.cat(states, dim=1)
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
else:
states = tf_to_torch(episode["observation"]["state"])
actions = tf_to_torch(episode["action"])
rewards = tf_to_torch(episode["reward"]).float()
# If lang_key is present, convert the entire tensor at once
if lang_key is not None:
ep_dict["language_instruction"] = [
x.numpy().decode("utf-8") for x in episode[lang_key]
]
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
image_array_dict = {key: [] for key in image_keys}
langs = [str(x) for x in episode[lang_key]]
for im_key in image_keys:
imgs = episode["observation"][im_key]
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
# simple assertions
for item in [states, actions, rewards, done]:
assert len(item) == num_frames
###########################################################
# loop through all cameras
for im_key in image_keys:
img_key = f"observation.images.{im_key}"
@@ -240,12 +257,22 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
if lang_key is not None:
ep_dict["language_instruction"] = langs
ep_dict["observation.state"] = states
ep_dict["action"] = actions
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["next.reward"] = rewards
ep_dict["next.done"] = done
path_ep_dict = tmp_ep_dicts_dir.joinpath(
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
)
@@ -263,30 +290,30 @@ def load_from_raw(
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
for key in data_dict:
# check if vector state obs
if key.startswith("observation.") and "observation.images." not in key:
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
# check if image obs
elif "observation.images." in key:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
if "language_instruction" in data_dict:
features["language_instruction"] = Value(dtype="string", id=None)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["is_terminal"] = Value(dtype="bool", id=None)
features["is_first"] = Value(dtype="bool", id=None)
features["discount"] = Value(dtype="float32", id=None)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
@@ -306,8 +333,19 @@ def from_raw_to_lerobot_format(
video: bool = True,
episodes: list[int] | None = None,
encoding: dict | None = None,
openx_dataset_name: str | None = None,
):
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
"""This is a test impl for rlds conversion"""
if openx_dataset_name is None:
# set a default rlds frame rate if the dataset is not from openx
fps = 30
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
raise ValueError(
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
)
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
hf_dataset = to_hf_dataset(data_dict, video)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {

View File

@@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -56,9 +56,7 @@ def check_format(raw_dir):
required_datasets.remove("meta/episode_ends")
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
def load_from_raw(
@@ -78,9 +76,7 @@ def load_from_raw(
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
success_threshold = 0.95 # 95% coverage,
@@ -154,9 +150,7 @@ def load_from_raw(
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@@ -165,9 +159,7 @@ def load_from_raw(
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
if keypoints_instead_of_image:
keypoints[i] = torch.from_numpy(
PushTEnv.get_keypoints(block_shapes).flatten()
)
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
# last step of demonstration is considered done
done[-1] = True
@@ -192,8 +184,7 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -202,9 +193,7 @@ def load_from_raw(
if keypoints_instead_of_image:
ep_dict["observation.environment_state"] = keypoints
ep_dict["action"] = actions[from_idx:to_idx]
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = image[1:],
@@ -231,8 +220,7 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if keypoints_instead_of_image:
features["observation.environment_state"] = Sequence(
@@ -273,9 +261,7 @@ def from_raw_to_lerobot_format(
if fps is None:
fps = 10
data_dict = load_from_raw(
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
)
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {

View File

@@ -26,16 +26,14 @@ from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
register_codecs,
)
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -63,9 +61,7 @@ def check_format(raw_dir) -> bool:
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
required_datasets.remove("meta/episode_ends")
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
def load_from_raw(
@@ -83,9 +79,7 @@ def load_from_raw(
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
eff_rot_axis_angle = torch.from_numpy(
zarr_data["data/robot0_eef_rot_axis_angle"][:]
)
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
@@ -135,31 +129,24 @@ def load_from_raw(
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor(
[from_idx + num_frames] * num_frames
)
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
@@ -185,8 +172,7 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
@@ -206,8 +192,7 @@ def to_hf_dataset(data_dict, video):
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
)
features["gripper_width"] = Sequence(
length=data_dict["gripper_width"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))

View File

@@ -16,9 +16,7 @@
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict
import datasets
import numpy
import PIL
import torch
@@ -45,9 +43,7 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@@ -57,10 +53,7 @@ def save_images_concurrently(
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
@@ -69,8 +62,7 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
}
@@ -80,60 +72,3 @@ def check_repo_id(repo_id: str) -> None:
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
)
# TODO(aliberts): remove
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index

View File

@@ -27,12 +27,12 @@ from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
get_default_encoding,
save_images_concurrently,
)
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
@@ -40,10 +40,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
keys = {"actions", "rewards", "dones"}
nested_keys = {
"observations": {"rgb", "state"},
"next_observations": {"rgb", "state"},
}
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
xarm_files = list(raw_dir.glob("*.pkl"))
assert len(xarm_files) > 0
@@ -56,17 +53,11 @@ def check_format(raw_dir):
# Check for consistent lengths in nested keys
expected_len = len(dataset_dict["actions"])
assert all(
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
)
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
for key, subkeys in nested_keys.items():
nested_dict = dataset_dict.get(key, {})
assert all(
len(nested_dict[subkey]) == expected_len
for subkey in subkeys
if subkey in nested_dict
)
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
def load_from_raw(
@@ -131,18 +122,13 @@ def load_from_raw(
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
@@ -167,8 +153,7 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
self.indices = indices

View File

@@ -57,9 +57,7 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
@@ -118,22 +116,16 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
return float(sharpness[0]), float(sharpness[1])
@@ -142,9 +134,7 @@ class SharpnessJitter(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
@@ -160,10 +150,6 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@@ -184,22 +170,6 @@ def get_image_transforms(
weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(
v2.Resize(
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
)
)
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -215,15 +185,6 @@ def get_image_transforms(
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if image_mean is not None and image_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=image_mean,
std=image_std,
)
)
n_subset = len(transforms)
if max_num_transforms is not None:
@@ -233,6 +194,4 @@ def get_image_transforms(
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)

View File

@@ -13,66 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
import re
import warnings
from functools import cache
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
from typing import Dict
import datasets
import jsonlines
import numpy as np
import pyarrow.compute as pc
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from datasets import load_dataset, load_from_disk
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.robot_devices.robots.utils import Robot
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = (
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
@@ -91,7 +56,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
return dict(items)
def unflatten_dict(d: dict, sep: str = "/") -> dict:
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@@ -104,89 +69,6 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {
key: value.tolist() for key, value in flatten_dict(stats).items()
}
return unflatten_dict(serialized_dict)
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
dataset.to_parquet(fpath)
def load_json(fpath: Path) -> Any:
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
return info
def load_stats(local_dir: Path) -> dict:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
def load_image_as_numpy(
fpath: str | Path, dtype="float32", channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if "float" in dtype:
img_array /= 255.0
return img_array
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
@@ -198,6 +80,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif isinstance(first_item, str):
# TODO (michel-aractingi): add str2embedding via language tokenizer
# For now we leave this part up to the user to choose how to address
# language conditioned tasks
pass
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
@@ -205,70 +95,19 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str,
version_to_check: str,
current_version: str,
enforce_breaking_major: bool = True,
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
def get_hub_safe_version(repo_id: str, version: str) -> str:
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
logging.warning(
warnings.warn(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
@@ -277,200 +116,275 @@ def get_hub_safe_version(repo_id: str, version: str) -> str:
return version
def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
else:
assert len(ft["shape"]) == 1
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
# TODO: (alibers, azouitine) Add support for ft["shap"] == 0 as Value
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
if "%" in split:
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
match_from = re.search(r"train\[(\d+):\]", split)
match_to = re.search(r"train\[:(\d+)\]", split)
if match_from:
from_frame_index = int(match_from.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
elif match_to:
to_frame_index = int(match_to.group(1))
hf_dataset = hf_dataset.select(range(to_frame_index))
else:
raise ValueError(
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
)
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
return datasets.Features(hf_features)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
camera_ft = {
key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items()
}
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
)
return load_file(path)
def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
features: dict,
use_videos: bool,
) -> dict:
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
"total_tasks": 0,
"total_videos": 0,
"total_chunks": 0,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
)
stats = load_file(path)
return unflatten_dict(stats)
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
safe_version = get_hf_dataset_safe_version(repo_id, version)
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
with open(path) as f:
info = json.load(f)
return info
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def load_videos(repo_id, version, root) -> Path:
if root is not None:
path = Path(root) / repo_id / "videos"
else:
# TODO(rcadene): we download the whole repo here. see if we can avoid this
safe_version = get_hf_dataset_safe_version(repo_id, version)
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
path = Path(repo_dir) / "videos"
return path
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def check_timestamps_sync(
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
# We mask differences between the timestamp at the end of an episode
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
if not torch.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(
~filtered_within_tolerance
) # .squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item(),
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
return True
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
) -> dict[torch.Tensor]:
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
]
if not all(within_tolerance):
outside_tolerance[key] = [
ts
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
if not is_within
]
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
frames in the dataset.
if len(outside_tolerance) > 0:
if raise_value_error:
raise ValueError(
f"""
The following delta_timestamps are found outside of tolerance range.
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
their values accordingly.
\n{pformat(outside_tolerance)}
"""
)
return False
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
of the episode are the same as the first frame of the episode.
return True
Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
Returns:
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
each modality (e.g. "observation.image_is_pad").
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
current_ts = item["timestamp"].item()
for key in delta_timestamps:
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
# get dataset indices corresponding to frames to be loaded
data_ids = ep_data_ids[argmin_]
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
if isinstance(item[key][0], dict) and "path" in item[key][0]:
# video mode where frame are expressed as dict of path and timestamp
item[key] = item[key]
else:
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_delta_indices(
delta_timestamps: dict[str, list[float]], fps: int
) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
return delta_indices
Parameters:
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
Returns:
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
- "from": A tensor containing the starting index of each episode.
- "to": A tensor containing the ending index of each episode.
"""
episode_data_index = {"from": [], "to": []}
current_episode = None
"""
The episode_index is a list of integers, each representing the episode index of the corresponding example.
For instance, the following is a valid episode_index:
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
{
"from": [0, 3, 7],
"to": [3, 7, 12]
}
"""
if len(hf_dataset) == 0:
episode_data_index = {
"from": torch.tensor([]),
"to": torch.tensor([]),
}
return episode_data_index
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
# We encountered a new episode, so we append its starting location to the "from" list
episode_data_index["from"].append(idx)
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
if current_episode is not None:
episode_data_index["to"].append(idx)
# Let's keep track of the current episode index
current_episode = episode_idx
else:
# We are still in the same episode, so there is nothing for us to do here
pass
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
episode_data_index["to"].append(idx + 1)
for k in ["from", "to"]:
episode_data_index[k] = torch.tensor(episode_data_index[k])
return episode_data_index
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
This brings the `episode_index` to the required format.
"""
if len(hf_dataset) == 0:
return hf_dataset
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
episode_idx_to_reset_idx_mapping = {
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
}
def modify_ep_idx_func(example):
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset
def cycle(iterable):
@@ -486,7 +400,7 @@ def cycle(iterable):
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
@@ -501,96 +415,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
def create_lerobot_dataset_card(
tags: list | None = None,
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
"""
Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
if tags:
card_tags += tags
if dataset_info:
dataset_structure = "[meta/info.json](meta/info.json):\n"
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
kwargs = {**kwargs, "dataset_structure": dataset_structure}
card_data = DatasetCardData(
license=kwargs.get("license"),
tags=card_tags,
task_categories=["robotics"],
configs=[
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
],
)
card_template = (
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
).read_text()
return DatasetCard.from_template(
card_data=card_data,
template_str=card_template,
**kwargs,
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None:
card.data.tags += tags
if text is not None:
card.text += text
return card

View File

@@ -1,924 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""
import traceback
from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import (
convert_dataset,
parse_robot_config,
)
LOCAL_DIR = Path("data/")
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
ALOHA_MOBILE_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"license": "mit",
"url": "https://mobile-aloha.github.io/",
"paper": "https://arxiv.org/abs/2401.02117",
"citation_bibtex": dedent(r"""
@inproceedings{fu2024mobile,
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
booktitle = {arXiv},
year = {2024},
}""").lstrip(),
}
ALOHA_STATIC_INFO = {
"robot_config": parse_robot_config(ALOHA_CONFIG),
"license": "mit",
"url": "https://tonyzhaozh.github.io/aloha/",
"paper": "https://arxiv.org/abs/2304.13705",
"citation_bibtex": dedent(r"""
@article{Zhao2023LearningFB,
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
journal={RSS},
year={2023},
volume={abs/2304.13705},
url={https://arxiv.org/abs/2304.13705}
}""").lstrip(),
}
PUSHT_INFO = {
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://arxiv.org/abs/2303.04137v5",
"citation_bibtex": dedent(r"""
@article{chi2024diffusionpolicy,
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
journal = {The International Journal of Robotics Research},
year = {2024},
}""").lstrip(),
}
XARM_INFO = {
"license": "mit",
"url": "https://www.nicklashansen.com/td-mpc/",
"paper": "https://arxiv.org/abs/2203.04955",
"citation_bibtex": dedent(r"""
@inproceedings{Hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={ICML},
year={2022}
}
"""),
}
UNITREEH_INFO = {
"license": "apache-2.0",
}
DATASETS = {
"aloha_mobile_cabinet": {
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_chair": {
"single_task": "Push the chairs in front of the desk to place them against it.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_elevator": {
"single_task": "Take the elevator to the 1st floor.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_shrimp": {
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wash_pan": {
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
**ALOHA_MOBILE_INFO,
},
"aloha_mobile_wipe_wine": {
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
**ALOHA_MOBILE_INFO,
},
"aloha_static_battery": {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee_new": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
**ALOHA_STATIC_INFO,
},
"aloha_static_cups_open": {
"single_task": "Pick up the plastic cup and open its lid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_fork_pick_up": {
"single_task": "Pick up the fork and place it on the plate.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pingpong_test": {
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
**ALOHA_STATIC_INFO,
},
"aloha_static_pro_pencil": {
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
**ALOHA_STATIC_INFO,
},
"aloha_static_screw_driver": {
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
**ALOHA_STATIC_INFO,
},
"aloha_static_tape": {
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
**ALOHA_STATIC_INFO,
},
"aloha_static_thread_velcro": {
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_towel": {
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {
"single_task": "Slide open the ziploc bag.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_scripted_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_transfer_cube_human_image": {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
},
"unitreeh1_warehouse": {
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",
},
"asu_table_top": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023modularity,
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
booktitle={Conference on Robot Learning},
pages={1684--1695},
year={2023},
organization={PMLR}
}
@article{zhou2023learning,
title={Learning modular language-conditioned robot policies through attention},
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
journal={Autonomous Robots},
pages={1--21},
year={2023},
publisher={Springer}
}""").lstrip(),
},
"austin_buds_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
"paper": "https://arxiv.org/abs/2109.13841",
"citation_bibtex": dedent(r"""
@article{zhu2022bottom,
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
journal={IEEE Robotics and Automation Letters},
volume={7},
number={2},
pages={4126--4133},
year={2022},
publisher={IEEE}
}""").lstrip(),
},
"austin_sailor_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sailor/",
"paper": "https://arxiv.org/abs/2210.11435",
"citation_bibtex": dedent(r"""
@inproceedings{nasiriany2022sailor,
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
booktitle={Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
"austin_sirius_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/sirius/",
"paper": "https://arxiv.org/abs/2211.08416",
"citation_bibtex": dedent(r"""
@inproceedings{liu2022robot,
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
booktitle = {Robotics: Science and Systems (RSS)},
year = {2023}
}""").lstrip(),
},
"berkeley_autolab_ur5": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/berkeley-ur5/home",
"citation_bibtex": dedent(r"""
@misc{BerkeleyUR5Website,
title = {Berkeley {UR5} Demonstration Dataset},
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
}""").lstrip(),
},
"berkeley_cable_routing": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://sites.google.com/view/cablerouting/home",
"paper": "https://arxiv.org/abs/2307.08927",
"citation_bibtex": dedent(r"""
@article{luo2023multistage,
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
journal = {arXiv pre-print},
year = {2023},
url = {https://arxiv.org/abs/2307.08927},
}""").lstrip(),
},
"berkeley_fanuc_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
"citation_bibtex": dedent(r"""
@article{fanuc_manipulation2023,
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
year={2023},
}""").lstrip(),
},
"berkeley_gnm_cory_hall": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/1709.10489",
"citation_bibtex": dedent(r"""
@inproceedings{kahn2018self,
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
pages={5129--5136},
year={2018},
organization={IEEE}
}""").lstrip(),
},
"berkeley_gnm_recon": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/recon-robot",
"paper": "https://arxiv.org/abs/2104.05859",
"citation_bibtex": dedent(r"""
@inproceedings{shah2021rapid,
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
booktitle={5th Annual Conference on Robot Learning },
year={2021},
url={https://openreview.net/forum?id=d_SWJhyKfVw}
}""").lstrip(),
},
"berkeley_gnm_sac_son": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/SACSoN-review",
"paper": "https://arxiv.org/abs/2306.01874",
"citation_bibtex": dedent(r"""
@article{hirose2023sacson,
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
journal={arXiv preprint arXiv:2306.01874},
year={2023}
}""").lstrip(),
},
"berkeley_mvp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/2203.06173",
"citation_bibtex": dedent(r"""
@InProceedings{Radosavovic2022,
title = {Real-World Robot Learning with Masked Visual Pre-training},
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
booktitle = {CoRL},
year = {2022}
}""").lstrip(),
},
"berkeley_rpt": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://arxiv.org/abs/2306.10007",
"citation_bibtex": dedent(r"""
@article{Radosavovic2023,
title={Robot Learning with Sensorimotor Pre-training},
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
year={2023},
journal={arXiv:2306.10007}
}""").lstrip(),
},
"cmu_franka_exploration_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://human-world-model.github.io/",
"paper": "https://arxiv.org/abs/2308.10901",
"citation_bibtex": dedent(r"""
@inproceedings{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={RSS},
year={2023}
}""").lstrip(),
},
"cmu_play_fusion": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-fusion.github.io/",
"paper": "https://arxiv.org/abs/2312.04549",
"citation_bibtex": dedent(r"""
@inproceedings{chen2023playfusion,
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
booktitle={CoRL},
year={2023}
}""").lstrip(),
},
"cmu_stretch": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robo-affordances.github.io/",
"paper": "https://arxiv.org/abs/2304.08488",
"citation_bibtex": dedent(r"""
@inproceedings{bahl2023affordances,
title={Affordances from Human Videos as a Versatile Representation for Robotics},
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
booktitle={CVPR},
year={2023}
}
@article{mendonca2023structured,
title={Structured World Models from Human Videos},
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
journal={CoRL},
year={2023}
}""").lstrip(),
},
"columbia_cairlab_pusht_real": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://diffusion-policy.cs.columbia.edu/",
"paper": "https://arxiv.org/abs/2303.04137v5",
"citation_bibtex": dedent(r"""
@inproceedings{chi2023diffusionpolicy,
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
year={2023}
}""").lstrip(),
},
"conq_hose_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
"citation_bibtex": dedent(r"""
@misc{ConqHoseManipData,
author={Peter Mitrano and Dmitry Berenson},
title={Conq Hose Manipulation Dataset, v1.15.0},
year={2024},
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
}""").lstrip(),
},
"dlr_edan_shared_control": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://ieeexplore.ieee.org/document/9341156",
"citation_bibtex": dedent(r"""
@inproceedings{vogel_edan_2020,
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
language = {en},
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
year = {2020}
}
@inproceedings{quere_shared_2020,
address = {Paris, France},
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
language = {en},
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
year = {2020},
pages = {7},
}""").lstrip(),
},
"dlr_sara_grid_clamp": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
"citation_bibtex": dedent(r"""
@article{padalkar2023guided,
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
journal={Research square preprint rs-3289569/v1},
year={2023}
}""").lstrip(),
},
"dlr_sara_pour": {
"tasks_col": "language_instruction",
"license": "mit",
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
"citation_bibtex": dedent(r"""
@inproceedings{padalkar2023guiding,
title={Guiding Reinforcement Learning with Shared Control Templates},
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"droid_100": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://droid-dataset.github.io/",
"paper": "https://arxiv.org/abs/2403.12945",
"citation_bibtex": dedent(r"""
@article{khazatsky2024droid,
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
year = {2024},
}""").lstrip(),
},
"fmb": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://functional-manipulation-benchmark.github.io/",
"paper": "https://arxiv.org/abs/2401.08553",
"citation_bibtex": dedent(r"""
@article{luo2024fmb,
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
journal={arXiv preprint arXiv:2401.08553},
year={2024}
}""").lstrip(),
},
"iamlab_cmu_pickup_insert": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
"paper": "https://arxiv.org/abs/2401.14502",
"citation_bibtex": dedent(r"""
@inproceedings{saxena2023multiresolution,
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=WuBv9-IGDUA}
}""").lstrip(),
},
"imperialcollege_sawyer_wrist_cam": {
"tasks_col": "language_instruction",
"license": "mit",
},
"jaco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
"citation_bibtex": dedent(r"""
@software{dass2023jacoplay,
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
title = {CLVR Jaco Play Dataset},
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
version = {1.0.0},
year = {2023}
}""").lstrip(),
},
"kaist_nonprehensile": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
"citation_bibtex": dedent(r"""
@article{kimpre,
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
year={2023},
organization={IEEE}
}""").lstrip(),
},
"nyu_door_opening_surprising_effectiveness": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://jyopari.github.io/VINN/",
"paper": "https://arxiv.org/abs/2112.01511",
"citation_bibtex": dedent(r"""
@misc{pari2021surprising,
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
year={2021},
eprint={2112.01511},
archivePrefix={arXiv},
primaryClass={cs.RO}
}""").lstrip(),
},
"nyu_franka_play_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://play-to-policy.github.io/",
"paper": "https://arxiv.org/abs/2210.10047",
"citation_bibtex": dedent(r"""
@article{cui2022play,
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
journal = {arXiv preprint arXiv:2210.10047},
year = {2022}
}""").lstrip(),
},
"nyu_rot_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://rot-robot.github.io/",
"paper": "https://arxiv.org/abs/2206.15469",
"citation_bibtex": dedent(r"""
@inproceedings{haldar2023watch,
title={Watch and match: Supercharging imitation with regularized optimal transport},
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
booktitle={Conference on Robot Learning},
pages={32--43},
year={2023},
organization={PMLR}
}""").lstrip(),
},
"roboturk": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://roboturk.stanford.edu/dataset_real.html",
"paper": "PAPER",
"citation_bibtex": dedent(r"""
@inproceedings{mandlekar2019scaling,
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
pages={1048--1055},
year={2019},
organization={IEEE}
}""").lstrip(),
},
"stanford_hydra_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/hydra-il-2023",
"paper": "https://arxiv.org/abs/2306.17237",
"citation_bibtex": dedent(r"""
@article{belkhale2023hydra,
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
journal={arxiv},
year={2023}
}""").lstrip(),
},
"stanford_kuka_multimodal_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://sites.google.com/view/visionandtouch",
"paper": "https://arxiv.org/abs/1810.10191",
"citation_bibtex": dedent(r"""
@inproceedings{lee2019icra,
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
year={2019},
url={https://arxiv.org/abs/1810.10191}
}""").lstrip(),
},
"stanford_robocook": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://hshi74.github.io/robocook/",
"paper": "https://arxiv.org/abs/2306.14447",
"citation_bibtex": dedent(r"""
@article{shi2023robocook,
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
journal={arXiv preprint arXiv:2306.14447},
year={2023}
}""").lstrip(),
},
"taco_play": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
"citation_bibtex": dedent(r"""
@inproceedings{rosete2022tacorl,
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
year = {2022}
}
@inproceedings{mees23hulc2,
title={Grounding Language with Visual Affordances over Unstructured Data},
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
year={2023},
address = {London, UK}
}""").lstrip(),
},
"tokyo_u_lsmo": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "URL",
"paper": "https://arxiv.org/abs/2107.05842",
"citation_bibtex": dedent(r"""
@Article{Osa22,
author = {Takayuki Osa},
journal = {The International Journal of Robotics Research},
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
year = {2022},
number = {3},
pages = {291--311},
volume = {41},
}""").lstrip(),
},
"toto": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://toto-benchmark.org/",
"paper": "https://arxiv.org/abs/2306.00942",
"citation_bibtex": dedent(r"""
@inproceedings{zhou2023train,
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
year={2023},
}""").lstrip(),
},
"ucsd_kitchen_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@ARTICLE{ucsd_kitchens,
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
title = {{ucsd kitchens Dataset}},
year = {2023},
month = {August}
}""").lstrip(),
},
"ucsd_pick_and_place_dataset": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://owmcorl.github.io/#",
"paper": "https://arxiv.org/abs/2310.16029",
"citation_bibtex": dedent(r"""
@preprint{Feng2023Finetuning,
title={Finetuning Offline World Models in the Real World},
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
year={2023}
}""").lstrip(),
},
"uiuc_d3field": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://robopil.github.io/d3fields/",
"paper": "https://arxiv.org/abs/2309.16118",
"citation_bibtex": dedent(r"""
@article{wang2023d3field,
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
journal={arXiv preprint arXiv:},
year={2023},
}""").lstrip(),
},
"usc_cloth_sim": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://uscresl.github.io/dmfd/",
"paper": "https://arxiv.org/abs/2207.10148",
"citation_bibtex": dedent(r"""
@article{salhotra2022dmfd,
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
journal={IEEE Robotics and Automation Letters},
title={Learning Deformable Object Manipulation From Expert Demonstrations},
year={2022},
volume={7},
number={4},
pages={8775-8782},
doi={10.1109/LRA.2022.3187843}
}""").lstrip(),
},
"utaustin_mutex": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/MUTEX/",
"paper": "https://arxiv.org/abs/2309.14320",
"citation_bibtex": dedent(r"""
@inproceedings{shah2023mutex,
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
booktitle={7th Annual Conference on Robot Learning},
year={2023},
url={https://openreview.net/forum?id=PwqiqaaEzJ}
}""").lstrip(),
},
"utokyo_pr2_opening_fridge": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_pr2_tabletop_manipulation": {
"tasks_col": "language_instruction",
"license": "mit",
"citation_bibtex": dedent(r"""
@misc{oh2023pr2utokyodatasets,
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
title={X-Embodiment U-Tokyo PR2 Datasets},
year={2023},
url={https://github.com/ojh6404/rlds_dataset_builder},
}""").lstrip(),
},
"utokyo_saytap": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://saytap.github.io/",
"paper": "https://arxiv.org/abs/2306.07580",
"citation_bibtex": dedent(r"""
@article{saytap2023,
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
Tatsuya Harada},
title = {SayTap: Language to Quadrupedal Locomotion},
eprint = {arXiv:2306.07580},
url = {https://saytap.github.io},
note = {https://saytap.github.io},
year = {2023}
}""").lstrip(),
},
"utokyo_xarm_bimanual": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"utokyo_xarm_pick_and_place": {
"tasks_col": "language_instruction",
"license": "cc-by-4.0",
"citation_bibtex": dedent(r"""
@misc{matsushima2023weblab,
title={Weblab xArm Dataset},
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
year={2023},
}""").lstrip(),
},
"viola": {
"tasks_col": "language_instruction",
"license": "mit",
"url": "https://ut-austin-rpl.github.io/VIOLA/",
"paper": "https://arxiv.org/abs/2210.11339",
"citation_bibtex": dedent(r"""
@article{zhu2022viola,
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
journal={6th Annual Conference on Robot Learning (CoRL)},
year={2022}
}""").lstrip(),
},
}
def batch_convert():
status = {}
logfile = LOCAL_DIR / "conversion_log.txt"
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
for num, (name, kwargs) in enumerate(DATASETS.items()):
repo_id = f"lerobot/{name}"
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
print("---------------------------------------------------------")
try:
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
status = f"{repo_id}: success."
with open(logfile, "a") as file:
file.write(status + "\n")
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
continue
if __name__ == "__main__":
batch_convert()

View File

@@ -1,774 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
We support 3 different scenarios for these tasks (see instructions below):
1. Single task dataset: all episodes of your dataset have the same single task.
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
one episode to the next.
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
recorded with. For now, only Aloha/Koch type robots are supported with this option.
# 1. Single task dataset
If your dataset contains a single task, you can simply provide it directly via the CLI with the
'--single-task' option.
Examples:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/aloha_sim_insertion_human_image \
--single-task "Insert the peg into the socket." \
--robot-config lerobot/configs/robot/aloha.yaml \
--local-dir data
```
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id aliberts/koch_tutorial \
--single-task "Pick the Lego block and drop it in the box on the right." \
--robot-config lerobot/configs/robot/koch.yaml \
--local-dir data
```
# 2. Single task episodes
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
this column's name with the '--tasks-col' arg.
Example:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
'--tasks-path' arg. This file should have the following structure where keys correspond to each
episode_index in the dataset, and values are the language instruction for that episode.
Example:
```json
{
"0": "Do something",
"1": "Do something else",
"2": "Do something",
"3": "Go there",
...
}
```
# 3. Multi task episodes
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
parquet file, and you must provide this column's name with the '--tasks-col' arg.
Example:
```bash
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
--repo-id lerobot/stanford_kuka_multimodal_dataset \
--tasks-col "language_instruction" \
--local-dir data
```
"""
import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
from pathlib import Path
import datasets
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_hub_safe_version,
load_json,
unflatten_dict,
write_json,
write_jsonlines,
)
from lerobot.common.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_pixel_channels,
get_video_info,
)
from lerobot.common.utils.utils import init_hydra_config
V16 = "v1.6"
V20 = "v2.0"
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(
config_path: Path, config_overrides: list[str] | None = None
) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
state_names = [
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
for arm in robot_cfg["follower_arms"]
for motor in robot_cfg["follower_arms"][arm]["motors"]
]
action_names = [
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
for arm in robot_cfg["leader_arms"]
for motor in robot_cfg["leader_arms"][arm]["motors"]
]
# elif robot_cfg["robot_type"] == "stretch3": TODO
else:
raise NotImplementedError(
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
)
return {
"robot_type": robot_cfg["robot_type"],
"names": {
"observation.state": state_names,
"observation.effort": state_names,
"action": action_names,
},
}
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
safetensor_path = v1_dir / V1_STATS_PATH
stats = load_file(safetensor_path)
serialized_stats = {key: value.tolist() for key, value in stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
json_path = v2_dir / STATS_PATH
json_path.parent.mkdir(exist_ok=True, parents=True)
with open(json_path, "w") as f:
json.dump(serialized_stats, f, indent=4)
# Sanity check
with open(json_path) as f:
stats_json = json.load(f)
stats_json = flatten_dict(stats_json)
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
for key in stats:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: dict | None = None
) -> dict[str, list]:
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
elif isinstance(ft, datasets.Image):
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.height, image.width, channels)
names = ["height", "width", "channel"]
elif ft._type == "VideoFrame":
dtype = "video"
shape = None # Add shape later
names = ["height", "width", "channel"]
features[key] = {
"dtype": dtype,
"shape": shape,
"names": names,
}
return features
def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
return dataset, tasks
def add_task_index_from_tasks_col(
dataset: Dataset, tasks_col: str
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
df = dataset.to_pandas()
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
# Create task_index col
tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
# Build the dataset back from df
features = dataset.features
features["task_index"] = datasets.Value(dtype="int64")
dataset = Dataset.from_pandas(df, features=features, split="train")
dataset = dataset.remove_columns(tasks_col)
return dataset, tasks, tasks_by_episode
def split_parquet_by_episodes(
dataset: Dataset,
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
table = dataset.data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx
)
pq.write_table(ep_table, output_file)
return episode_lengths
def move_videos(
repo_id: str,
video_keys: list[str],
total_episodes: int,
total_chunks: int,
work_dir: Path,
clean_gittatributes: Path,
branch: str = "main",
) -> None:
"""
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
commands to fetch git lfs video files references to move them into subdirectories without having to
actually download them.
"""
_lfs_clone(repo_id, work_dir, branch)
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
current_gittatributes = work_dir / ".gitattributes"
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
if lfs_untracked_videos:
fix_lfs_video_files_tracking(work_dir, video_files)
if videos_moved:
return
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
for dir in video_dirs:
if (dir / video_file).is_file():
video_path = dir / video_file
break
video_path.rename(work_dir / target_path)
commit_message = "Move video files into chunk subdirectories"
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
"""
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
commit_message = "Track video files with git lfs"
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
[
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True,
env=env,
)
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"],
cwd=work_dir,
capture_output=True,
text=True,
check=True,
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
return videos_info_dict
def convert_dataset(
repo_id: str,
local_dir: Path,
single_task: str | None = None,
tasks_path: Path | None = None,
tasks_col: Path | None = None,
robot_config: dict | None = None,
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_hub_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id,
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
)
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names:
logging.warning(
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
)
single_task = None
tasks_col = "language_instruction"
# Episodes & chunks
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
# Tasks
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
dataset, tasks_col
)
else:
raise ValueError
assert set(tasks) == {
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
}
tasks = [
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
# Videos
if video_keys:
assert metadata_v1.get("video", False)
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF,
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id,
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
)
videos_info = get_videos_info(
repo_id, v1x_dir, video_keys=video_keys, branch=branch
)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
videos_info[key].pop("video.width"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
)
if "encoding" in metadata_v1:
assert (
videos_info[key]["video.pix_fmt"]
== metadata_v1["encoding"]["pix_fmt"]
)
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(
dataset, total_episodes, total_chunks, v20_dir
)
if robot_config is not None:
robot_type = robot_config["robot_type"]
repo_tags = [robot_type]
else:
robot_type = "unknown"
repo_tags = None
# Episodes
episodes = [
{
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="data",
folder_path=v20_dir / "data",
repo_type="dataset",
revision=branch,
)
hub_api.upload_folder(
repo_id=repo_id,
path_in_repo="meta",
folder_path=v20_dir / "meta",
repo_type="dataset",
revision=branch,
)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
def main():
parser = argparse.ArgumentParser()
task_args = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the single task performed in the dataset.",
)
task_args.add_argument(
"--tasks-col",
type=str,
help="The name of the column containing language instructions",
)
task_args.add_argument(
"--tasks-path",
type=Path,
help="The path to a .json file containing one language instruction for each episode_index",
)
parser.add_argument(
"--robot-config",
type=Path,
default=None,
help="Path to the robot's config yaml the dataset during conversion.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
)
parser.add_argument(
"--local-dir",
type=Path,
default=None,
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
)
parser.add_argument(
"--license",
type=str,
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(
"--test-branch",
type=str,
default=None,
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
)
args = parser.parse_args()
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = (
parse_robot_config(args.robot_config, args.robot_overrides)
if args.robot_config
else None
)
del args.robot_config, args.robot_overrides
convert_dataset(**vars(args), robot_config=robot_config)
if __name__ == "__main__":
main()

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import subprocess
import warnings
@@ -26,11 +25,47 @@ import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from PIL import Image
def load_from_videos(
item: dict[str, torch.Tensor],
video_frame_keys: list[str],
videos_dir: Path,
tolerance_s: float,
backend: str = "pyav",
):
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
This probably happens because a memory reference to the video loader is created in the main process and a
subprocess fails to access it.
"""
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
data_dir = videos_dir.parent
for key in video_frame_keys:
if isinstance(item[key], list):
# load multiple frames at once (expected when delta_timestamps is not None)
timestamps = [frame["timestamp"] for frame in item[key]]
paths = [frame["path"] for frame in item[key]]
if len(set(paths)) > 1:
raise NotImplementedError("All video paths are expected to be the same for now.")
video_path = data_dir / paths[0]
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames
else:
# load one frame
timestamps = [item[key]["timestamp"]]
video_path = data_dir / item[key]["path"]
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
item[key] = frames[0]
return item
def decode_video_frames_torchvision(
video_path: Path | str,
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
@@ -128,8 +163,8 @@ def decode_video_frames_torchvision(
def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
imgs_dir: Path,
video_path: Path,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
@@ -212,110 +247,3 @@ with warnings.catch_warnings():
)
# to make VideoFrame available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the information, defaulting to None if no audio stream is present
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
def get_video_info(video_path: Path | str) -> dict:
ffprobe_video_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
"-of",
"json",
str(video_path),
]
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
video_stream_info = info["streams"][0]
# Calculate fps from r_frame_rate
r_frame_rate = video_stream_info["r_frame_rate"]
num, denom = map(int, r_frame_rate.split("/"))
fps = num / denom
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
video_info = {
"video.fps": fps,
"video.height": video_stream_info["height"],
"video.width": video_stream_info["width"],
"video.channels": pixel_channels,
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
"video.is_depth_map": False,
**get_audio_info(video_path),
}
return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")

View File

@@ -14,13 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque
import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
# from mani_skill.utils import common
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
@@ -34,12 +30,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
if cfg.env.name == "real_world":
return
if "maniskill" in cfg.env.name:
env = make_maniskill_env(
cfg, n_envs if n_envs is not None else cfg.eval.batch_size
)
return env
package_name = f"gym_{cfg.env.name}"
try:
@@ -57,11 +47,7 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c)
env_cls = (
gym.vector.AsyncVectorEnv
if cfg.eval.use_async_envs
else gym.vector.SyncVectorEnv
)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
@@ -70,99 +56,3 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
)
return env
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
num_envs=n_envs,
)
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
env = ManiSkillVectorEnv(env, ignore_terminations=True)
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
return env
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Works with Maniskill vectorized environments
"""
def __init__(self, cfg, env, num_envs, num_frames=3):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
dtype=np.uint8,
)
self._frames = deque([], maxlen=num_frames)
self._render_size = cfg.env.render_size
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
for _ in range(self._frames.maxlen):
obs_frames = self._get_obs(obs)
return obs_frames, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
# TODO: Remove this
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret

View File

@@ -28,32 +28,28 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
for key, img in observations.items():
if "images" not in key:
continue
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert (
c < h and c < w
), f"expect channel last images, but instead got {img.shape=}"
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
return_observations[key] = img
# obs state agent qpos and qvel
# image
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
@@ -62,43 +58,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = observations["observation.state"].float()
return return_observations
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.state"] = state
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return return_observations

View File

@@ -25,7 +25,6 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -84,9 +83,7 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
):
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
@@ -106,12 +103,12 @@ class Logger:
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
)
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
@@ -131,12 +128,8 @@ class Logger:
job_type="train_eval",
resume="must" if cfg.resume else None,
)
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
@@ -157,9 +150,7 @@ class Logger:
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
):
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
@@ -182,32 +173,18 @@ class Logger:
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer | dict,
optimizer: Optimizer,
scheduler: LRScheduler | None,
interaction_step: int | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
# In Sac, for example, we have a dictionary of torch.optim.Optimizer
if type(optimizer) is dict:
optimizer_state_dict = {}
for k in optimizer:
optimizer_state_dict[k] = optimizer[k].state_dict()
else:
optimizer_state_dict = optimizer.state_dict()
training_state = {
"step": train_step,
"optimizer": optimizer_state_dict,
"optimizer": optimizer.state_dict(),
**get_global_random_state(),
}
# Interaction step is related to the distributed training code
# In that setup, we have two kinds of steps, the online step of the env and the optimization step
# We need to save both in order to resume the optimization properly and not break the logs dependant on the interaction step
if interaction_step is not None:
training_state["interaction_step"] = interaction_step
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
@@ -219,7 +196,6 @@ class Logger:
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
interaction_step: int | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -229,34 +205,18 @@ class Logger:
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name,
policy,
wandb_artifact_name=wandb_artifact_name,
)
self.save_training_state(
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
) -> int:
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(
self.last_checkpoint_dir / self.training_state_file_name
)
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
optimizer.load_state_dict(training_state["optimizer"])
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
@@ -264,63 +224,20 @@ class Logger:
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state(
{k: training_state[k] for k in get_global_random_state()}
)
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def log_dict(
self,
d,
step: int | None = None,
mode="train",
custom_step_key: str | None = None,
):
"""Log a dictionary of metrics to WandB."""
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
if self._wandb is not None:
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps is possible for example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str, wandb.Table)):
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}

View File

@@ -168,6 +168,4 @@ class ACTConfig:
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
raise ValueError("You must provide at least one image or the environment state among the inputs.")

View File

@@ -81,14 +81,10 @@ class ACTPolicy(
self.model = ACT(config)
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
@@ -111,12 +107,8 @@ class ACTPolicy(
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -143,18 +135,13 @@ class ACTPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@@ -164,12 +151,7 @@ class ACTPolicy(
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
@@ -223,9 +205,7 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
@@ -241,9 +221,7 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@@ -251,34 +229,19 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@@ -330,9 +293,7 @@ class ACT(nn.Module):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(
k.startswith("observation.image") for k in config.input_shapes
)
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
@@ -347,9 +308,7 @@ class ACT(nn.Module):
config.output_shapes["action"][0], config.dim_model
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@@ -357,28 +316,20 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.use_images:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@@ -392,8 +343,7 @@ class ACT(nn.Module):
)
if self.use_env_state:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0],
config.dim_model,
config.input_shapes["observation.environment_state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
@@ -408,18 +358,14 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(
config.dim_model, config.output_shapes["action"][0]
)
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self._reset_parameters()
@@ -429,9 +375,7 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@@ -468,20 +412,12 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_robot_state:
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -519,26 +455,20 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
@@ -547,29 +477,19 @@ class ACT(nn.Module):
all_cam_pos_embeds = []
for cam_index in range(batch["observation.images"].shape[-4]):
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
"feature_map"
]
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
# buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(
cam_features
) # (B, C, h, w)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
)
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
)
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
@@ -604,21 +524,12 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -629,9 +540,7 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -646,9 +555,7 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@@ -673,9 +580,7 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@@ -687,10 +592,7 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)
@@ -700,12 +602,8 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -746,9 +644,7 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -785,14 +681,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
@@ -837,9 +728,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@@ -847,15 +736,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
return pos_embed

View File

@@ -121,9 +121,7 @@ class DiffusionConfig:
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.
@@ -165,13 +163,8 @@ class DiffusionConfig:
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if (
len(image_keys) == 0
and "observation.environment_state" not in self.input_shapes
):
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if len(image_keys) > 0:
if self.crop_shape is not None:

View File

@@ -88,9 +88,7 @@ class DiffusionPolicy(
self.diffusion = DiffusionModel(config)
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.use_env_state = "observation.environment_state" in config.input_shapes
self.reset()
@@ -104,9 +102,7 @@ class DiffusionPolicy(
if len(self.expected_image_keys) > 0:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.use_env_state:
self._queues["observation.environment_state"] = deque(
maxlen=self.config.n_obs_steps
)
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -132,22 +128,14 @@ class DiffusionPolicy(
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
@@ -162,12 +150,8 @@ class DiffusionPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@@ -193,9 +177,7 @@ class DiffusionModel(nn.Module):
# Build observation encoders (depending on which observations are provided).
global_cond_dim = config.input_shapes["observation.state"][0]
num_images = len(
[k for k in config.input_shapes if k.startswith("observation.image")]
)
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self._use_images = False
self._use_env_state = False
if num_images > 0:
@@ -211,9 +193,7 @@ class DiffusionModel(nn.Module):
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@@ -233,21 +213,14 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(
batch_size,
self.config.horizon,
self.config.output_shapes["action"][0],
),
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
dtype=dtype,
device=device,
generator=generator,
@@ -263,9 +236,7 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
@@ -277,39 +248,27 @@ class DiffusionModel(nn.Module):
if self._use_images:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
)
global_cond_feats.append(img_features)
@@ -395,9 +354,7 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
@@ -457,9 +414,7 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -501,9 +456,7 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -524,9 +477,7 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
@@ -534,25 +485,17 @@ class DiffusionRgbEncoder(nn.Module):
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape
if config.crop_shape is not None
else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -579,9 +522,7 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
@@ -594,11 +535,7 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -613,9 +550,7 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
@@ -643,9 +578,7 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@@ -668,13 +601,9 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
@@ -699,16 +628,10 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
@@ -717,14 +640,10 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
),
]
)
@@ -737,24 +656,16 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
),
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
)
@@ -822,23 +733,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:

View File

@@ -51,10 +51,15 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy, TDMPCConfig
elif name == "tdmpc2":
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
return TDMPC2Policy, TDMPC2Config
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import (
DiffusionConfig,
)
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy, DiffusionConfig
@@ -68,21 +73,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig
elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy(
hydra_cfg: DictConfig,
pretrained_policy_name_or_path: str | None = None,
dataset_stats=None,
*args,
**kwargs,
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
"""Make an instance of a policy class.
@@ -96,19 +92,17 @@ def make_policy(
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
# raise ValueError(
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
# )
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
)
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
# Make a fresh policy.
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
# for example device for the sac policy.
policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
policy = policy_cls(policy_cfg, dataset_stats)
else:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
@@ -117,9 +111,7 @@ def make_policy(
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(
policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
policy.to(get_safe_torch_device(hydra_cfg.device))

View File

@@ -1,35 +0,0 @@
import json
import os
from dataclasses import asdict, dataclass
@dataclass
class ClassifierConfig:
"""Configuration for the Classifier model."""
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
def save_pretrained(self, save_dir):
"""Save config to json file."""
os.makedirs(save_dir, exist_ok=True)
# Convert to dict and save as JSON
config_dict = asdict(self)
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path):
"""Load config from json file."""
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
with open(config_file) as f:
config_dict = json.load(f)
return cls(**config_dict)

View File

@@ -1,173 +0,0 @@
import logging
from typing import Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(
nn.Module,
PyTorchModelHubMixin,
# Add Hub metadata
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vision-classifier"],
):
"""Image classifier built on top of a pre-trained encoder."""
# Add name attribute for factory
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoModel
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.feature_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization)
# processed = self.processor(
# images=x, # LeRobotDataset already provides proper tensor format
# return_tensors="pt",
# )
# processed = processed["pixel_values"].to(x.device)
processed = x
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoder(processed)
# Get pooled output directly
features = outputs.pooler_output
if features.dim() > 2:
features = features.squeeze(-1).squeeze(-1)
return features
else: # Transformer models
outputs = self.encoder(processed)
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
probs = self.forward(x).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)

View File

@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass

View File

@@ -130,7 +130,7 @@ class Normalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():

View File

@@ -1,108 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {
"mean": [[0.485, 0.456, 0.406]],
"std": [[0.229, 0.224, 0.225]],
},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
}
)
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
camera_number: int = 1
storage_device: str = "cpu"
# Add type annotations for these fields:
vision_encoder_name: str | None = field(default="helper2424/resnet10")
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2
num_subsample_critics: int | None = None
critic_lr: float = 3e-4
actor_lr: float = 3e-4
temperature_lr: float = 3e-4
critic_target_update_weight: float = 0.005
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
)
actor_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
}
)
policy_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.05,
}
)

View File

@@ -1,981 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: (1) better device management
import math
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
name = "sac"
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes,
config.input_normalization_modes,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
ensemble=critic_heads,
output_normalization=self.normalize_targets,
)
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
ensemble=target_critic_heads,
output_normalization=self.normalize_targets,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
temperature_init = config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import os
import json
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import os
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file) as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
load_model(model, filename=safetensors_file, device=map_location)
return model
def reset(self):
"""Reset the policy"""
pass
def to(self, *args, **kwargs):
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
"action"
]
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(
observations,
actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
return temperature_loss
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
q_preds = self.critic_forward(
observations,
actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.activate_final = activate_final
layers = []
# First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
else getattr(nn, final_activation)()
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
┌──────────────────┬─────────────────────────────────────────────────────────┐
│ Critic Ensemble │ │
├──────────────────┘ │
│ │
│ ┌────┐ ┌────┐ ┌────┐ │
│ │ Q1 │ │ Q2 │ │ Qn │ │
│ └────┘ └────┘ └────┘ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ │ │ │ │ │ │
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
│ │ │ │ │ ... │ num_critics │ │
│ │ │ │ │ │ │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ ▲ ▲ ▲ │
│ └───────────────────┴───────┬────────────────────────────┘ │
│ │ │
│ │ │
│ ┌───────────────────┐ │
│ │ Embedding │ │
│ │ │ │
│ └───────────────────┘ │
│ ▲ │
│ │ │
│ ┌─────────────┴────────────┐ │
│ │ │ │
│ │ SACObservationEncoder │ │
│ │ │ │
│ └──────────────────────────┘ │
│ ▲ │
│ │ │
│ │ │
│ │ │
└───────────────────────────┬────────────────────┬───────────────────────────┘
│ Observation │
└────────────────────┘
"""
def __init__(
self,
encoder: Optional[nn.Module],
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where a part of the encoder if frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions it helps for sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: Normalization layer took dict in input and outputs a dict that why
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
self.parameters_to_optimize += list(self.mean_layer.parameters())
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
def forward(
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
device = get_device_from_parameters(self)
observations = observations.to(device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.environment_state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
def forward(self, x):
return self.image_enc_layers(x)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
return converted_params
if __name__ == "__main__":
# Benchmark the CriticEnsemble performance
import time
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
# print("\nVerifying critic outputs are different:")
# for i in range(num_critics):
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")

View File

@@ -191,10 +191,6 @@ class TDMPCConfig:
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError(
"`n_action_steps` must be less than or equal to `horizon`."
)
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")

View File

@@ -68,9 +68,7 @@ class TDMPCPolicy(
name = "tdmpc"
def __init__(
self,
config: TDMPCConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
"""
Args:
@@ -102,9 +100,7 @@ class TDMPCPolicy(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
@@ -124,9 +120,7 @@ class TDMPCPolicy(
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
@@ -141,9 +135,7 @@ class TDMPCPolicy(
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -217,20 +209,13 @@ class TDMPCPolicy(
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon,
batch_size,
self.config.output_shapes["action"][0],
device=device,
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -246,47 +231,35 @@ class TDMPCPolicy(
self.config.output_shapes["action"][0],
device=std.device,
)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
@@ -295,9 +268,7 @@ class TDMPCPolicy(
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
return actions
@@ -320,8 +291,7 @@ class TDMPCPolicy(
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@@ -341,22 +311,15 @@ class TDMPCPolicy(
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
dim=0,
)[0]
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
@@ -368,9 +331,7 @@ class TDMPCPolicy(
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
@@ -388,10 +349,7 @@ class TDMPCPolicy(
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
)
@@ -409,20 +367,14 @@ class TDMPCPolicy(
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@@ -436,14 +388,10 @@ class TDMPCPolicy(
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -486,9 +434,7 @@ class TDMPCPolicy(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@@ -526,14 +472,12 @@ class TDMPCPolicy(
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -588,9 +532,7 @@ class TDMPCPolicy(
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
class TDMPCTOLD(nn.Module):
@@ -601,9 +543,7 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
),
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -614,9 +554,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
),
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -636,10 +574,7 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
),
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -684,9 +619,7 @@ class TDMPCTOLD(nn.Module):
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@@ -784,32 +717,14 @@ class TDMPCObservationEncoder(nn.Module):
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0],
config.image_encoder_hidden_dim,
7,
stride=2,
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
@@ -825,10 +740,7 @@ class TDMPCObservationEncoder(nn.Module):
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.state"][0],
config.state_encoder_hidden_dim,
),
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -837,8 +749,7 @@ class TDMPCObservationEncoder(nn.Module):
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0],
config.state_encoder_hidden_dim,
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
@@ -855,15 +766,9 @@ class TDMPCObservationEncoder(nn.Module):
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict["observation.image"]
)
)
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
@@ -906,17 +811,12 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
@@ -924,9 +824,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:

View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class TDMPC2Config:
"""Configuration class for TDMPC2Policy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
Args:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
is an alternative to using action repeats. If this is set to more than 1, then we require
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
approach of using multiple steps from the plan is not in the original implementation.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
match the original implementation.
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
normalization mode here.
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
latent_dim: Observation's latent embedding dimension.
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
(π), Q ensemble, and V.
discount: Discount factor (γ) to use for the reinforcement learning formalism.
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
(π) for each step.
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
be non-zero.
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
be zero.
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
elites, when updating the gaussian parameters for CEM.
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
is applied. Note that the input images are assumed to be square for this augmentation.
reward_coeff: Loss weighting coefficient for the reward regression loss.
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
value (V) expectile regression loss.
consistency_coeff: Loss weighting coefficient for the consistency loss.
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
current time step.
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
model being trained.
"""
# Input / output structure.
n_action_repeats: int = 1
horizon: int = 3
n_action_steps: int = 1
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 512
q_ensemble_size: int = 5
num_enc_layers: int = 2
mlp_dim: int = 512
# Reinforcement learning.
discount: float = 0.9
simnorm_dim: int = 8
dropout: float = 0.01
# actor
log_std_min: float = -10
log_std_max: float = 2
# critic
num_bins: int = 101
vmin: int = -10
vmax: int = +10
# Inference.
use_mpc: bool = True
cem_iterations: int = 6
max_std: float = 2.0
min_std: float = 0.05
n_gaussian_samples: int = 512
n_pi_samples: int = 24
n_elites: int = 64
elite_weighting_temperature: float = 0.5
# Training and loss computation.
max_random_shift_ratio: float = 0.0476
# Loss coefficients.
reward_coeff: float = 0.1
value_coeff: float = 0.1
consistency_coeff: float = 20.0
entropy_coef: float = 1e-4
temporal_decay_coeff: float = 0.5
# Target model. NOTE (michel_aractingi) this is equivelant to
# 1 - target_model_momentum of our TD-MPC1 implementation because
# of the use of `torch.lerp`
target_model_momentum: float = 0.01
def __post_init__(self):
"""Input validation (not exhaustive)."""
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
)
if len(image_keys) > 0:
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.output_normalization_modes != {"action": "min_max"}:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)
if self.n_action_steps > 1:
if self.n_action_repeats != 1:
raise ValueError(
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")

View File

@@ -0,0 +1,834 @@
#!/usr/bin/env python
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
We refer to the main paper and codebase:
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
"""
# ruff: noqa: N806
from collections import deque
from copy import deepcopy
from functools import partial
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
NormedLinear,
SimNorm,
gaussian_logprob,
soft_cross_entropy,
squash,
two_hot_inv,
)
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
class TDMPC2Policy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc2"],
):
"""Implementation of TD-MPC2 learning + inference."""
name = "tdmpc2"
def __init__(
self, config: TDMPC2Config | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = TDMPC2Config()
self.config = config
self.model = TDMPC2WorldModel(config)
# TODO (michel-aractingi) temp fix for gpu
self.model = self.model.to("cuda:0")
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.scale = RunningScale(self.config.target_model_momentum)
self.discount = (
self.config.discount
) # TODO (michel-aractingi) downscale discount according to episode length
self.reset()
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
called on `env.reset()`
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
encode_keys = []
if self._use_image:
encode_keys.append("observation.image")
if self._use_env_state:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
if self.config.use_mpc: # noqa: SIM108
actions = self.plan(z) # (horizon, batch, action_dim)
else:
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
# sequence dimension like in the MPC branch.
actions = self.model.pi(z)[0].unsqueeze(0)
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.n_action_repeats > 1:
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(actions[0])
else:
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
self._queues["action"].extend(actions[: self.config.n_action_steps])
action = self._queues["action"].popleft()
return action
@torch.no_grad()
def plan(self, z: Tensor) -> Tensor:
"""Plan sequence of actions using TD-MPC inference.
Args:
z: (batch, latent_dim,) tensor for the initial state.
Returns:
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
"""
device = get_device_from_parameters(self)
batch_size = z.shape[0]
# Sample Nπ trajectories from the policy.
pi_actions = torch.empty(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.output_shapes["action"][0],
device=device,
)
if self.config.n_pi_samples > 0:
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
for t in range(self.config.horizon):
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
# helpful for CEM.
pi_actions[t] = self.model.pi(_z)[0]
_z = self.model.latent_dynamics(_z, pi_actions[t])
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
mean[:-1] = self._prev_mean[1:]
std = self.config.max_std * torch.ones_like(mean)
for _ in range(self.config.cem_iterations):
# Randomly sample action trajectories for the gaussian distribution.
std_normal_noise = torch.randn(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.output_shapes["action"][0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0).squeeze()
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
)
std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
).clamp_(self.config.min_std, self.config.max_std)
# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
return actions
@torch.no_grad()
def estimate_value(self, z: Tensor, actions: Tensor):
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
Args:
z: (batch, latent_dim) tensor of initial latent states.
actions: (horizon, batch, action_dim) tensor of action trajectories.
Returns:
(batch,) tensor of values.
"""
# Initialize return and running discount factor.
G, running_discount = 0, 1
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
# model. Keep track of return.
for t in range(actions.shape[0]):
# Estimate the next state (latent) and reward.
z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
# Update the return and running discount.
G += running_discount * reward
running_discount *= self.config.discount
# next_action = self.model.pi(z)[0] # (batch, action_dim)
# terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
info = {}
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b, action_dim)
reward = batch["next.reward"] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
)
# Get the current observation for predicting trajectories, and all future observations for use in
# the latent consistency loss and TD loss.
current_observation, next_observations = {}, {}
for k in observations:
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty(horizon, batch_size, self.config.num_bins, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
# Compute Q value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action, return_type="all"
) # (ensemble, horizon, batch)
info.update({"Q": q_preds_ensemble.mean().item()})
# Compute various targets with stopgrad.
with torch.no_grad():
# Latent state consistency targets for consistency loss.
z_targets = self.model.encode(next_observations)
# Compute the TD-target from a reward and the next observation
pi = self.model.pi(z_targets)[0]
td_targets = (
reward
+ self.config.discount
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
temporal_loss_coeffs = torch.pow(
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
).unsqueeze(-1)
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* soft_cross_entropy(reward_preds, reward, self.config).mean(1)
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
ce_value_loss = 0.0
for i in range(self.config.q_ensemble_size):
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
q_value_loss = (
(
temporal_loss_coeffs
* ce_value_loss
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
z_preds = z_preds.detach()
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
with torch.no_grad():
# avoid unnessecary computation of the gradients during policy optimization
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
self.scale.update(qs[0])
qs = self.scale(qs)
pi_loss = (
(self.config.entropy_coef * log_pis - qs).mean(dim=2)
* temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
loss = (
self.config.consistency_coeff * consistency_loss
+ self.config.reward_coeff * reward_loss
+ self.config.value_coeff * q_value_loss
+ pi_loss
)
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
"pi_scale": float(self.scale.value),
}
)
# Undo (b, t) -> (t, b).
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
def update(self):
"""Update the target model's using polyak averaging."""
self.model.update_target_Q()
class TDMPC2WorldModel(nn.Module):
"""Latent dynamics model used in TD-MPC2."""
def __init__(self, config: TDMPC2Config):
super().__init__()
self.config = config
self._encoder = TDMPC2ObservationEncoder(config)
# Define latent dynamics head
self._dynamics = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
# Define reward head
self._reward = nn.Sequential(
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
# Define policy head
self._pi = nn.Sequential(
NormedLinear(config.latent_dim, config.mlp_dim),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
)
# Define ensemble of Q functions
self._Qs = nn.ModuleList(
[
nn.Sequential(
NormedLinear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
dropout=config.dropout,
),
NormedLinear(config.mlp_dim, config.mlp_dim),
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
)
for _ in range(config.q_ensemble_size)
]
)
self._init_weights()
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(config.log_std_min)
self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min
self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)
def _init_weights(self):
"""Initialize model weights.
Custom weight initializations proposed in TD-MPC2.
"""
def _apply_fn(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.ParameterList):
for i, p in enumerate(m):
if p.dim() == 3: # Linear
nn.init.trunc_normal_(p, std=0.02) # Weight
nn.init.constant_(m[i + 1], 0) # Bias
self.apply(_apply_fn)
# initialize parameters of the
for m in [self._reward, *self._Qs]:
assert isinstance(
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
nn.init.zeros_(m[-1].weight)
def to(self, *args, **kwargs):
"""
Overriding `to` method to also move additional tensors to device.
"""
super().to(*args, **kwargs)
self.log_std_min = self.log_std_min.to(*args, **kwargs)
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
self.bins = self.bins.to(*args, **kwargs)
return self
def train(self, mode):
super().train(mode)
self._target_Qs.train(False)
return self
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
return self._encoder(obs)
def latent_dynamics_and_reward(
self, z: Tensor, a: Tensor, discretize_reward: bool = False
) -> tuple[Tensor, Tensor, bool]:
"""Predict the next state's latent representation and the reward given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
A tuple containing:
- (*, latent_dim) tensor for the next state's latent representation.
- (*,) tensor for the estimated reward.
"""
x = torch.cat([z, a], dim=-1)
reward = self._reward(x).squeeze(-1)
if discretize_reward:
reward = two_hot_inv(reward, self.bins)
return self._dynamics(x), reward
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
"""Predict the next state's latent representation given a current latent and action.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
Returns:
(*, latent_dim) tensor for the next state's latent representation.
"""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x)
def pi(self, z: Tensor) -> Tensor:
"""Samples an action from the learned policy.
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
generating rollouts for online training.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
std: The standard deviation of the injected noise.
Returns:
(*, action_dim) tensor for the sampled action.
"""
mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = self.log_std_min + 0.5 * self.log_std_dif * (torch.tanh(log_std) + 1)
eps = torch.randn_like(mu)
log_pi = gaussian_logprob(eps, log_std)
pi = mu + eps * log_std.exp()
mu, pi, log_pi = squash(mu, pi, log_pi)
return pi, mu, log_pi, log_std
def Qs(self, z: Tensor, a: Tensor, return_type: str = "min", target=False) -> Tensor: # noqa: N802
"""Predict state-action value for all of the learned Q functions.
Args:
z: (*, latent_dim) tensor for the current state's latent representation.
a: (*, action_dim) tensor for the action to be applied.
return_type: either 'min' or 'all' otherwise the average is returned
Returns:
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble or the average or min
"""
x = torch.cat([z, a], dim=-1)
if target:
out = torch.stack([q(x).squeeze(-1) for q in self._target_Qs], dim=0)
else:
out = torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
if return_type == "all":
return out
Q1, Q2 = out[np.random.choice(len(self._Qs), size=2, replace=False)]
Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
def update_target_Q(self):
"""
Soft-update target Q-networks using Polyak averaging.
"""
with torch.no_grad():
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
p_target.data.lerp_(p.data, self.config.target_model_momentum)
class TDMPC2ObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: TDMPC2Config):
"""
Creates encoders for pixel and/or state modalities.
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
channel dimension. Re-implement this capability.
"""
super().__init__()
self.config = config
# Define the observation encoder whether its pixels or states
encoder_dict = {}
for obs_key in config.input_shapes:
if "observation.image" in config.input_shapes:
encoder_module = nn.Sequential(
nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
)
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
with torch.inference_mode():
out_shape = encoder_module(dummy_batch).shape[1:]
encoder_module.extend(
nn.Sequential(
nn.Flatten(),
NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
)
)
elif (
"observation.state" in config.input_shapes
or "observation.environment_state" in config.input_shapes
):
encoder_module = nn.ModuleList()
encoder_module.append(
NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
)
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(
NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
)
encoder_module.append(
NormedLinear(
config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
)
)
encoder_module = nn.Sequential(*encoder_module)
else:
raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
encoder_dict[obs_key.replace(".", "")] = encoder_module
self.encoder = nn.ModuleDict(encoder_dict)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
for obs_key in self.config.input_shapes:
if "observation.image" in obs_key:
feat.append(
flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
)
else:
feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
return torch.stack(feat, dim=0).mean(0)
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
"""Randomly shifts images horizontally and vertically.
Adapted from https://github.com/facebookresearch/drqv2
"""
b, _, h, w = x.size()
assert h == w, "non-square images not handled yet"
pad = int(round(max_random_shift_ratio * h))
x = F.pad(x, tuple([pad] * 4), "replicate")
eps = 1.0 / (h + 2 * pad)
arange = torch.linspace(
-1.0 + eps,
1.0 - eps,
h + 2 * pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = einops.repeat(arange, "w -> h w 1", h=h)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
# A random shift in units of pixels and within the boundaries of the padding.
shift = torch.randint(
0,
2 * pad + 1,
size=(b, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * pad)
grid = base_grid + shift
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
class RunningScale:
"""Running trimmed scale estimator."""
def __init__(self, tau):
self.tau = tau
self._value = torch.ones(1, dtype=torch.float32, device=torch.device("cuda"))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device("cuda"))
def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles)
def load_state_dict(self, state_dict):
self._value.data.copy_(state_dict["value"])
self._percentiles.data.copy_(state_dict["percentiles"])
@property
def value(self):
return self._value.cpu().item()
def _percentile(self, x):
x_dtype, x_shape = x.dtype, x.shape
x = x.view(x.shape[0], -1)
in_sorted, _ = torch.sort(x, dim=0)
positions = self._percentiles * (x.shape[0] - 1) / 100
floored = torch.floor(positions)
ceiled = floored + 1
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
weight_ceiled = positions - floored
weight_floored = 1.0 - weight_ceiled
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
return (d0 + d1).view(-1, *x_shape[1:]).type(x_dtype)
def update(self, x):
percentiles = self._percentile(x.detach())
value = torch.clamp(percentiles[1] - percentiles[0], min=1.0)
self._value.data.lerp_(value, self.tau)
def __call__(self, x, update=False):
if update:
self.update(x)
return x * (1 / self.value)
def __repr__(self):
return f"RunningScale(S: {self.value})"

View File

@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import combine_state_for_ensemble
class Ensemble(nn.Module):
"""
Vectorized ensemble of modules.
"""
def __init__(self, modules, **kwargs):
super().__init__()
modules = nn.ModuleList(modules)
fn, params, _ = combine_state_for_ensemble(modules)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
self._repr = str(modules)
def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs)
def __repr__(self):
return "Vectorized " + self._repr
class SimNorm(nn.Module):
"""
Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
shp = x.shape
x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1)
return x.view(*shp)
def __repr__(self):
return f"SimNorm(dim={self.dim})"
class NormedLinear(nn.Linear):
"""
Linear layer with LayerNorm, activation, and optionally dropout.
"""
def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features)
self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
def forward(self, x):
x = super().forward(x)
if self.dropout:
x = self.dropout(x)
return self.act(self.ln(x))
def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return (
f"NormedLinear(in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.bias is not None}{repr_dropout}, "
f"act={self.act.__class__.__name__})"
)
def soft_cross_entropy(pred, target, cfg):
"""Computes the cross entropy loss between predictions and soft targets."""
pred = F.log_softmax(pred, dim=-1)
target = two_hot(target, cfg)
return -(target * pred).sum(-1, keepdim=True)
@torch.jit.script
def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1)
@torch.jit.script
def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std
@torch.jit.script
def _gaussian_logprob(residual):
return residual - 0.5 * torch.log(2 * torch.pi)
def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None:
size = eps.size(-1)
return _gaussian_logprob(residual) * size
@torch.jit.script
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
def squash(mu, pi, log_pi):
"""Apply squashing function."""
mu = torch.tanh(mu)
pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True)
return mu, pi, log_pi
@torch.jit.script
def symlog(x):
"""
Symmetric logarithmic function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * torch.log(1 + torch.abs(x))
@torch.jit.script
def symexp(x):
"""
Symmetric exponential function.
Adapted from https://github.com/danijar/dreamerv3.
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
def two_hot(x, cfg):
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
# x shape [horizon, num_features]
if cfg.num_bins == 0:
return x
elif cfg.num_bins == 1:
return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() # shape [num_features]
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) # shape [num_features , 1]
soft_two_hot = torch.zeros(
*x.shape, cfg.num_bins, device=x.device
) # shape [horizon, num_features, num_bins]
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot
def two_hot_inv(x, bins):
"""Converts a batch of soft two-hot encoded vectors to scalars."""
num_bins = bins.shape[0]
if num_bins == 0:
return x
elif num_bins == 1:
return symexp(x)
x = F.softmax(x, dim=-1)
x = torch.sum(x * bins, dim=-1, keepdim=True)
return symexp(x)

View File

@@ -109,9 +109,7 @@ class VQBeTConfig:
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.

View File

@@ -79,9 +79,7 @@ class VQBeTPolicy(
self.vqbet = VQBeTModel(config)
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.reset()
@@ -106,12 +104,8 @@ class VQBeTPolicy(
"""
batch = self.normalize_inputs(batch)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -122,14 +116,8 @@ class VQBeTPolicy(
)
if len(self._queues["action"]) == 0:
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -142,12 +130,8 @@ class VQBeTPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -155,9 +139,7 @@ class VQBeTPolicy(
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(
self.config.n_vqvae_training_steps, batch["action"]
)
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
)
return {
"loss": loss,
@@ -214,9 +196,7 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -308,17 +288,14 @@ class VQBeTModel(nn.Module):
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.num_images = len(
[k for k in config.input_shapes if k.startswith("observation.image")]
)
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.input_shapes["observation.state"][0],
hidden_channels=[self.config.gpt_input_dim],
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -333,12 +310,7 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
@@ -353,11 +325,7 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@@ -369,19 +337,13 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@@ -390,9 +352,9 @@ class VQBeTModel(nn.Module):
features = self.policy(input_tokens)
# len(self.config.input_shapes) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (
len(self.config.input_shapes) + 1
) + len(self.config.input_shapes)
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
self.config.input_shapes
)
# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -400,11 +362,7 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
)
else:
features = features[:, historical_act_pred_index]
@@ -412,15 +370,13 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
# if rollout, VQ-BeT don't calculate loss
if rollout:
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
return action_head_output, loss
@@ -455,9 +411,7 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@@ -484,10 +438,7 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@@ -534,13 +485,7 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
axis=1,
)
)
@@ -548,29 +493,19 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@@ -590,17 +525,9 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
# pass the centroids through decoder to get actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@@ -649,9 +576,7 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
# Now we can compute the loss.
@@ -674,12 +599,8 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@@ -693,9 +614,7 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@@ -724,17 +643,11 @@ class VQBeTOptimizer(torch.optim.Adam):
if cfg.policy.sequentially_select:
decay_params = (
decay_params
+ list(
policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
)
+ list(
policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
)
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
)
else:
decay_params = decay_params + list(
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()
)
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
optim_groups = [
{
@@ -780,11 +693,7 @@ class VQBeTScheduler(nn.Module):
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0,
0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
@@ -808,9 +717,7 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -831,9 +738,7 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
@@ -841,25 +746,17 @@ class VQBeTRgbEncoder(nn.Module):
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape
if config.crop_shape is not None
else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -886,9 +783,7 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
@@ -901,11 +796,7 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -920,9 +811,7 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
@@ -955,8 +844,7 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0]
* self.config.action_chunk_size,
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -984,13 +872,9 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
)
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
else:
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
)
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)

View File

@@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
@@ -197,16 +189,12 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@@ -220,17 +208,11 @@ class GPT(nn.Module):
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@@ -255,9 +237,7 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@@ -290,9 +270,7 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert (
@@ -390,12 +368,8 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.num_quantizers = num_quantizers
@@ -403,10 +377,7 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
)
for _ in range(num_quantizers)
]
@@ -477,9 +448,7 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@@ -508,17 +477,13 @@ class ResidualVQ(nn.Module):
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
ce_losses = []
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
# sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@@ -527,23 +492,14 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@@ -583,9 +539,7 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
ret = (quantized_out, all_indices, all_losses)
@@ -645,12 +599,8 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.eps = eps
self.commitment_weight = commitment_weight
@@ -664,14 +614,10 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@@ -683,9 +629,7 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
codebook_kwargs = {
"dim": codebook_dim,
@@ -850,17 +794,11 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
if self.training:
# determine code to use for commitment loss
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
commit_quantize = maybe_detach(quantize)
@@ -870,9 +808,7 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -905,9 +841,7 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -961,12 +895,8 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@@ -998,9 +928,7 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if mask is not None:
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
return quantize, embed_ind, loss
@@ -1110,9 +1038,7 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
def pad_shape(shape, size, dim=0):
@@ -1163,9 +1089,7 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
else:
samples_per_rank = torch.empty_like(all_num_samples)
@@ -1278,9 +1202,7 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@@ -1291,14 +1213,8 @@ class EuclideanCodebook(nn.Module):
use_ddp and num_codebooks > 1 and kmeans_init
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@@ -1437,9 +1353,7 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_numer)
batch_variance = variance_numer / num_vectors
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@@ -1448,9 +1362,7 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@@ -1474,9 +1386,7 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
)
x = x.float()
@@ -1504,9 +1414,7 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
dist = -cdist(flatten, embed)
@@ -1524,9 +1432,7 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@@ -1549,9 +1455,7 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
dist = unpack_one(dist, ps, "h * d")

View File

@@ -65,9 +65,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100)
logging.info(f"Saved image: {path}")
except Exception as e:
logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
def save_images_from_cameras(
@@ -96,9 +94,7 @@ def save_images_from_cameras(
cameras = []
for cam_sn in serial_numbers:
print(f"{cam_sn=}")
camera = IntelRealSenseCamera(
cam_sn, fps=fps, width=width, height=height, mock=mock
)
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
camera.connect()
print(
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
@@ -144,9 +140,7 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s:
break
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
frame_index += 1
finally:
@@ -174,7 +168,6 @@ class IntelRealSenseCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
use_depth: bool = False
force_hardware_reset: bool = True
rotation: int | None = None
@@ -186,14 +179,8 @@ class IntelRealSenseCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
at_least_one_is_not_none = (
self.fps is not None or self.width is not None or self.height is not None
)
at_least_one_is_none = (
self.fps is None or self.width is None or self.height is None
)
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -201,9 +188,7 @@ class IntelRealSenseCameraConfig:
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class IntelRealSenseCamera:
@@ -269,7 +254,6 @@ class IntelRealSenseCamera:
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.use_depth = config.use_depth
self.force_hardware_reset = config.force_hardware_reset
@@ -298,9 +282,7 @@ class IntelRealSenseCamera:
self.rotation = cv2.ROTATE_180
@classmethod
def init_from_name(
cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs
):
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
camera_infos = find_cameras()
camera_names = [cam["name"] for cam in camera_infos]
this_name_count = Counter(camera_names)[name]
@@ -310,9 +292,7 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
)
name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
cam_sn = name_to_serial_dict[name]
if config is None:
@@ -339,17 +319,13 @@ class IntelRealSenseCamera:
if self.fps and self.width and self.height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
)
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
if self.fps and self.width and self.height:
config.enable_stream(
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
)
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
else:
config.enable_stream(rs.stream.depth)
@@ -382,9 +358,7 @@ class IntelRealSenseCamera:
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@@ -404,9 +378,7 @@ class IntelRealSenseCamera:
self.is_connected = True
def read(
self, temporary_color: str | None = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
@@ -433,15 +405,11 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@@ -469,9 +437,7 @@ class IntelRealSenseCamera:
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
depth_map = np.asanyarray(depth_frame.get_data())
@@ -513,9 +479,7 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)

View File

@@ -31,14 +31,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
@@ -169,9 +165,7 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
if time.perf_counter() - start_time > record_time_s:
break
@@ -198,7 +192,6 @@ class OpenCVCameraConfig:
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
@@ -208,12 +201,8 @@ class OpenCVCameraConfig:
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
class OpenCVCamera:
@@ -255,12 +244,7 @@ class OpenCVCamera:
```
"""
def __init__(
self,
camera_index: int | str,
config: OpenCVCameraConfig | None = None,
**kwargs,
):
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
if config is None:
config = OpenCVCameraConfig()
@@ -274,21 +258,16 @@ class OpenCVCamera:
if platform.system() == "Linux":
if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(
f"Please check the provided camera_index: {camera_index}"
)
raise ValueError(f"Please check the provided camera_index: {camera_index}")
self.fps = config.fps
self.width = config.width
self.height = config.height
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock
@@ -315,9 +294,7 @@ class OpenCVCamera:
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
if self.mock:
import tests.mock_cv2 as cv2
@@ -328,11 +305,7 @@ class OpenCVCamera:
# when other threads are used to save the images.
cv2.setNumThreads(1)
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx)
@@ -372,22 +345,16 @@ class OpenCVCamera:
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
if self.width is not None and not math.isclose(
self.width, actual_width, rel_tol=1e-3
):
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
raise OSError(
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
)
if self.height is not None and not math.isclose(
self.height, actual_height, rel_tol=1e-3
):
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
raise OSError(
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
@@ -417,9 +384,7 @@ class OpenCVCamera:
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@@ -11,29 +11,19 @@ from copy import copy
from functools import cache
import cv2
import numpy as np
import torch
import tqdm
from deepdiff import DeepDiff
from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_hydra_config,
set_global_seed,
)
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.scripts.eval import get_pretrained_policy_path
def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -42,7 +32,7 @@ def log_control_info(
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
@@ -104,9 +94,7 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
@@ -129,22 +117,14 @@ def predict_action(observation, policy, device, use_amp):
return action
def init_keyboard_listener(assign_rewards=False):
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
Args:
assign_rewards (bool): If True, allows annotating the collected trajectory
with a binary reward at the end of the episode to indicate success.
"""
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if assign_rewards:
events["next.reward"] = 0
if is_headless():
logging.warning(
@@ -162,22 +142,13 @@ def init_keyboard_listener(assign_rewards=False):
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif assign_rewards and key == keyboard.Key.space:
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
print(
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
events["next.reward"],
)
except Exception as e:
print(f"Error handling key press: {e}")
@@ -190,12 +161,8 @@ def init_keyboard_listener(assign_rewards=False):
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(
pretrained_policy_path / "config.yaml", policy_overrides
)
policy = make_policy(
hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path
)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -214,7 +181,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
def warmup_record(
robot,
events,
enable_teleoperation,
enable_teloperation,
warmup_time_s,
display_cameras,
fps,
@@ -225,7 +192,7 @@ def warmup_record(
display_cameras=display_cameras,
events=events,
fps=fps,
teleoperate=enable_teleoperation,
teleoperate=enable_teloperation,
)
@@ -239,7 +206,6 @@ def record_episode(
device,
use_amp,
fps,
record_delta_actions,
):
control_loop(
robot=robot,
@@ -251,7 +217,6 @@ def record_episode(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
teleoperate=policy is None,
)
@@ -262,13 +227,12 @@ def control_loop(
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset: LeRobotDataset | None = None,
dataset=None,
events=None,
policy=None,
device=None,
use_amp=None,
fps=None,
record_delta_actions=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -283,22 +247,16 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
)
if dataset is not None and fps is not None and dataset["fps"] != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
if record_delta_actions:
action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()
@@ -310,23 +268,12 @@ def control_loop(
action = {"action": action}
if dataset is not None:
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
frame["next.done"] = (events["next.reward"] == 1) or (
events["exit_early"]
)
dataset.add_frame(frame)
# if frame["next.done"]:
# break
add_frame(dataset, observation, action)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if fps is not None:
@@ -350,8 +297,6 @@ def reset_environment(robot, events, reset_time_s):
timestamp = 0
start_vencod_t = time.perf_counter()
if "next.reward" in events:
events["next.reward"] = 0
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
@@ -364,16 +309,6 @@ def reset_environment(robot, events, reset_time_s):
break
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an aribtrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
@@ -389,47 +324,7 @@ def sanity_check_dataset_name(repo_id, policy):
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy is None:
if dataset_name.startswith("eval_") == (policy is None):
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
)
# Check if dataset_name does not start with "eval_" but policy is provided
if not dataset_name.startswith("eval_") and policy is not None:
raise ValueError(
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
)
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, features_from_robot),
]
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
)

View File

@@ -8,10 +8,7 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0
@@ -146,9 +143,7 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -383,9 +378,7 @@ class DynamixelMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
@@ -401,9 +394,7 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -424,9 +415,7 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -439,9 +428,7 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -516,9 +503,7 @@ class DynamixelMotorsBus:
return values
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -560,23 +545,15 @@ class DynamixelMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-(resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
(resolution // 2) - values[i] - homing_offset
) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
@@ -584,9 +561,7 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -602,27 +577,19 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -632,9 +599,7 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -673,9 +638,7 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@@ -777,9 +740,7 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -788,9 +749,7 @@ class DynamixelMotorsBus:
return values
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@@ -819,12 +778,7 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@@ -885,9 +839,7 @@ class DynamixelMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -8,10 +8,7 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 0
@@ -125,9 +122,7 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -363,9 +358,7 @@ class FeetechMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
@@ -381,9 +374,7 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -404,9 +395,7 @@ class FeetechMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -419,9 +408,7 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -495,9 +482,7 @@ class FeetechMotorsBus:
return values
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -536,26 +521,18 @@ class FeetechMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@@ -564,9 +541,7 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -582,27 +557,19 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -612,9 +579,7 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -690,9 +655,7 @@ class FeetechMotorsBus:
return values
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@@ -797,9 +760,7 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -808,9 +769,7 @@ class FeetechMotorsBus:
return values
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@@ -839,12 +798,7 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@@ -905,9 +859,7 @@ class FeetechMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -10,7 +10,9 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -21,9 +23,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(position, drive_mode):
@@ -64,16 +64,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -94,15 +90,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -111,15 +102,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()

View File

@@ -12,7 +12,9 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -23,9 +25,7 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(position, drive_mode):
@@ -126,9 +126,7 @@ def apply_offset(calib, offset):
return calib
def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
@@ -137,27 +135,18 @@ def run_arm_auto_calibration(
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -204,16 +193,11 @@ def run_arm_auto_calibration_so100(
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
time.sleep(1)
def in_between_move_hook():
@@ -241,30 +225,18 @@ def run_arm_auto_calibration_so100(
}
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
time.sleep(2)
arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
time.sleep(2)
arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(
arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
@@ -274,9 +246,7 @@ def run_arm_auto_calibration_so100(
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
@@ -305,27 +275,18 @@ def run_arm_auto_calibration_so100(
return calib_dict
def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -409,12 +370,8 @@ def run_arm_auto_calibration_moss(
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
time.sleep(2)
calib_modes = []
@@ -441,9 +398,7 @@ def run_arm_auto_calibration_moss(
return calib_dict
def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
@@ -466,16 +421,12 @@ def run_arm_manual_calibration(
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
raise ValueError("To run calibration, the torque must be disabled on all motors.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -495,15 +446,10 @@ def run_arm_manual_calibration(
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -515,9 +461,7 @@ def run_arm_manual_calibration(
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()

View File

@@ -18,16 +18,11 @@ import torch
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
def ensure_safe_goal_position(
goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
@@ -37,7 +32,7 @@ def ensure_safe_goal_position(
safe_goal_pos = present_pos + safe_diff
if not torch.allclose(goal_pos, safe_goal_pos):
logging.debug(
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
@@ -72,14 +67,8 @@ class ManipulatorRobotConfig:
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
joint_position_relative_bounds: dict[np.ndarray] | None = None
def __setattr__(self, prop: str, val):
if (
prop == "max_relative_target"
and val is not None
and isinstance(val, Sequence)
):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(val):
raise ValueError(
@@ -89,16 +78,11 @@ class ManipulatorRobotConfig:
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
if prop == "joint_position_relative_bounds" and val is not None:
for key in val:
val[key] = torch.tensor(val[key])
super().__setattr__(prop, val)
def __post_init__(self):
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
raise ValueError(
f"Provided robot type ({self.robot_type}) is not supported."
)
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
class ManipulatorRobot:
@@ -242,42 +226,6 @@ class ManipulatorRobot:
self.is_connected = False
self.logs = {}
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
action_names = self.get_motor_names(self.leader_arms)
state_names = self.get_motor_names(self.leader_arms)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0
@@ -352,9 +300,7 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
# Check both arms can be read
for name in self.follower_arms:
@@ -386,26 +332,18 @@ class ManipulatorRobot:
print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
elif self.robot_type in ["so100", "moss"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@@ -424,17 +362,13 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@@ -463,9 +397,7 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
def set_aloha_robot_preset(self):
def set_shadow_(arm):
@@ -495,15 +427,11 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
name for name in self.follower_arms[name].motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
# It can grasp an object without forcing too much even tho,
@@ -551,9 +479,7 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
# Send goal position to the follower
follower_goal_pos = {}
@@ -561,31 +487,19 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
# Early exit when recording data is not requested
if not record_data:
@@ -598,9 +512,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -622,12 +534,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
@@ -651,9 +559,7 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
# Create state by concatenating follower current position
state = []
@@ -668,12 +574,8 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries and format to pytorch
obs_dict = {}
@@ -706,29 +608,18 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Save tensor to concat and return
action_sent.append(goal_pos)
# Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent)

View File

@@ -60,9 +60,7 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for name in self.cameras:
@@ -70,9 +68,7 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print(
"Could not connect to the cameras, check that all cameras are plugged-in."
)
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.run_calibration()
@@ -117,12 +113,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries
obs_dict, action_dict = {}, {}
@@ -166,12 +158,8 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionnaries
obs_dict = {}

View File

@@ -11,7 +11,6 @@ def get_arm_id(name, arm_type):
class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str
features: dict
def connect(self): ...
def run_calibration(self): ...

View File

@@ -34,8 +34,7 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
):
self.message = message
super().__init__(self.message)

View File

@@ -17,9 +17,7 @@ import importlib
import logging
def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.

View File

@@ -22,8 +22,6 @@ def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning,
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
)
imageio.mimsave(video_path, stacked_frames, fps=fps)

View File

@@ -18,7 +18,6 @@ import os
import os.path as osp
import platform
import random
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
@@ -116,11 +115,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
set_global_random_state(random_state_dict)
def init_logging(log_file=None):
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@@ -134,12 +133,6 @@ def init_logging(log_file=None):
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# File handler
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -162,16 +155,11 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
def init_hydra_config(
config_path: str, overrides: list[str] | None = None
) -> DictConfig:
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
@@ -180,11 +168,7 @@ def init_hydra_config(
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(
_relative_path_between(
Path(config_path).absolute().parent, Path(__file__).absolute().parent
)
),
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)
@@ -198,26 +182,10 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
def capture_timestamp_utc():
@@ -249,33 +217,3 @@ def log_say(text, play_sounds, blocking=False):
if play_sounds:
say(text, blocking)
class TimerManager:
def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
def __enter__(self):
self.start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
@property
def elapsed_seconds(self):
return self.elapsed

View File

@@ -2,7 +2,6 @@ defaults:
- _self_
- env: pusht
- policy: diffusion
- robot: so100
hydra:
run:

View File

@@ -1,30 +0,0 @@
# @package _global_
fps: 400
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 64
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 64
device: cuda
reward_classifier:
pretrained_path: null
config_path: null
wrapper:
joint_masking_action_space: null
delta_action: null
video_record:
enabled: false
record_dir: maniskill_videos
trajectory_name: trajectory
fps: ${fps}

View File

@@ -1,50 +1,10 @@
# @package _global_
fps: 10
fps: 30
env:
name: real_world
task: null
state_dim: 15
action_dim: 3
state_dim: 6
action_dim: 6
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [171, 207, 116, 251]
observation.images.side: [232, 200, 142, 204]
resize_size: [128, 128]
control_time_s: 10
reset_follower_pos: false
use_relative_joint_positions: true
reset_time_s: 5
display_cameras: false
delta_action: null #0.3
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
add_joint_velocity_to_observation: true
add_ee_pose_to_observation: true
# If null then the teleoperation will be used to reset the robot
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
# ee_action_space_params: # If null then ee_action_space is not used
# bounds:
# max: [0.291, 0.147, 0.074]
# min: [0.139, -0.143, 0.03]
# Bounds for insertcube_gamepad dataset and experiments
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
ee_action_space_params:
bounds:
max: [0.25295413, 0.07498981, 0.06862044]
min: [0.2010096, -0.12, 0.0433196]
use_gamepad: true
x_step_size: 0.03
y_step_size: 0.03
z_step_size: 0.03
reward_classifier:
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml

View File

@@ -114,7 +114,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@@ -95,7 +95,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_coeff: null
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1

View File

@@ -1,61 +0,0 @@
# @package _global_
defaults:
- _self_
hydra:
run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
job:
name: default
seed: 13
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
# aractingi/push_cube_square_reward_1_cropped_resized
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
local_files_only: true
train_split_proportion: 0.8
# Required by logger
env:
name: "classifier"
task: "binary_classification"
training:
num_epochs: 6
batch_size: 16
learning_rate: 1e-4
num_workers: 4
grad_clip_norm: 10
use_amp: true
log_freq: 1
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_keys: ["observation.images.front", "observation.images.side"]
label_key: "next.reward"
profile_inference_time: false
profile_inference_time_iters: 20
eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier"
model_name: "helper2424/resnet10" # "facebook/convnext-base-224
model_type: "cnn"
num_cameras: 2 # Has to be len(training.image_keys)
wandb:
enable: false
project: "classifier-training"
job_name: "classifier_training_0"
disable_artifact: false
device: "mps"
resume: false
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"

View File

@@ -1,118 +0,0 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# +dataset=lerobot/pusht_keypoints
# env=pusht \
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
dataset_repo_id: null
training:
# Offline training dataloader
num_workers: 4
batch_size: 512
grad_clip_norm: 40.0
lr: 3e-4
storage_device: "cuda"
eval_freq: 2500
log_freq: 10
save_freq: 1000000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 200000
offline_buffer_capacity: 100000
online_buffer_seed_size: 0
online_step_before_learning: 500
do_online_rollout_async: false
policy_update_freq: 1
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 1
n_action_steps: 1
shared_encoder: true
# vision_encoder_name: "helper2424/resnet10"
vision_encoder_name: null
# freeze_vision_encoder: true
freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 64, 64]
output_shapes:
action: [7]
camera_number: 1
# Normalization / Unnormalization
# input_normalization_modes: null
input_normalization_modes:
observation.state: min_max
observation.image: mean_std
# input_normalization_params: null
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
0.4001]
observation.image:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03]
max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
output_normalization_shapes:
action: [7]
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 2 #10
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 4
concurrency:
actor: 'threads'
learner: 'threads'

View File

@@ -1,89 +0,0 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# env=pusht \
# +dataset=lerobot/pusht_keypoints
seed: 1
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 0
# Offline training dataloader
num_workers: 4
batch_size: 128
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 50000
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 40000
online_buffer_seed_size: 0
do_online_rollout_async: false
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 5
n_action_steps: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Neural networks.
# image_encoder_hidden_dim: 32
discount: 0.99
temperature_init: 1.0
num_critics: 2
num_subsample_critics: None
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
critic_target_update_weight: 0.005
utd_ratio: 2
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -1,120 +0,0 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# +dataset=lerobot/pusht_keypoints
# env=pusht \
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: aractingi/insertcube_simple
training:
# Offline training dataloader
num_workers: 4
# batch_size: 256
batch_size: 512
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 2500
log_freq: 1
save_freq: 2000000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 10000
online_buffer_seed_size: 0
online_step_before_learning: 100 #5000
do_online_rollout_async: false
policy_update_freq: 1
# delta_timestamps:
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# action: "[i / ${fps} for i in range(${policy.horizon})]"
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 1
n_action_steps: 1
shared_encoder: true
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.images.front: [3, 128, 128]
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.front: mean_std
observation.images.side: mean_std
observation.state: min_max
input_normalization_params:
observation.images.front:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.images.side:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
# 6- joint positions, 6- joint velocities, 3- ee position
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-0.03, -0.03, -0.01]
max: [0.03, 0.03, 0.03]
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.97
temperature_init: 1.0
num_critics: 2 #10
camera_number: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 15
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -10,7 +10,7 @@ max_relative_target: null
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem58760430441
port: /dev/tty.usbmodem575E0031751
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
@@ -23,7 +23,7 @@ leader_arms:
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem585A0083391
port: /dev/tty.usbmodem575E0032081
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]

Some files were not shown because too many files have changed in this diff Show More