Compare commits
7 Commits
user/miche
...
tdmpc23
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14490148f3 | ||
|
|
16edbbdeee | ||
|
|
15090c2544 | ||
|
|
166c1fc776 | ||
|
|
31984645da | ||
|
|
c41ec08ec1 | ||
|
|
a146544765 |
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
|
||||
|
||||
Examples:
|
||||
```bash
|
||||
pytest -sx tests/test_stuff.py::test_something
|
||||
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
```bash
|
||||
python lerobot/scripts/train.py --some.option=true
|
||||
|
||||
8
.github/workflows/nightly-tests.yml
vendored
8
.github/workflows/nightly-tests.yml
vendored
@@ -7,8 +7,10 @@ on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
# env:
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
|
||||
jobs:
|
||||
run_all_tests_cpu:
|
||||
name: CPU
|
||||
@@ -28,9 +30,13 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Tests
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: pytest -v --cov=./lerobot --disable-warnings tests
|
||||
|
||||
- name: Tests end-to-end
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: make test-end-to-end
|
||||
|
||||
|
||||
|
||||
4
.github/workflows/quality.yml
vendored
4
.github/workflows/quality.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
65
.github/workflows/test.yml
vendored
65
.github/workflows/test.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
name: Pytest
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -69,6 +70,7 @@ jobs:
|
||||
name: Pytest (minimal install)
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -101,39 +103,40 @@ jobs:
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
||||
# end-to-end:
|
||||
# name: End-to-end
|
||||
# runs-on: ubuntu-latest
|
||||
# env:
|
||||
# MUJOCO_GL: egl
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# with:
|
||||
# lfs: true # Ensure LFS files are pulled
|
||||
|
||||
# - name: Install apt dependencies
|
||||
# # portaudio19-dev is needed to install pyaudio
|
||||
# run: |
|
||||
# sudo apt-get update && \
|
||||
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
# - name: Install poetry
|
||||
# run: |
|
||||
# pipx install poetry && poetry config virtualenvs.in-project true
|
||||
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
- name: Install apt dependencies
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
# - name: Set up Python 3.10
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: "3.10"
|
||||
# cache: "poetry"
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
# - name: Install poetry dependencies
|
||||
# run: |
|
||||
# poetry install --all-extras
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
# - name: Test end-to-end
|
||||
# run: |
|
||||
# make test-end-to-end \
|
||||
# && rm -rf outputs
|
||||
- name: Install poetry dependencies
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
|
||||
- name: Test end-to-end
|
||||
run: |
|
||||
make test-end-to-end \
|
||||
&& rm -rf outputs
|
||||
|
||||
@@ -3,7 +3,7 @@ default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v4.6.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: debug-statements
|
||||
@@ -14,12 +14,11 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.0
|
||||
rev: v3.16.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.8.2
|
||||
rev: v0.5.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -33,6 +32,6 @@ repos:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.21.2
|
||||
rev: v8.18.4
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
@@ -267,7 +267,7 @@ We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
python -m pytest -sv ./tests
|
||||
DATA_DIR="tests/data" python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
|
||||
|
||||
16
README.md
16
README.md
@@ -68,7 +68,7 @@
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
@@ -153,12 +153,10 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
@@ -210,10 +208,12 @@ dataset attributes:
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space
|
||||
- metadata are stored in plain json/jsonl files
|
||||
- videos are stored in mp4 format to save space or png files
|
||||
- episode_data_index saved using `safetensor` tensor serialization format
|
||||
- stats saved using `safetensor` tensor serialization format
|
||||
- info are saved using JSON
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ How to decode videos?
|
||||
|
||||
## Variables
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
- `torchaudio`
|
||||
|
||||
@@ -32,11 +32,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import (
|
||||
mean_squared_error,
|
||||
peak_signal_noise_ratio,
|
||||
structural_similarity,
|
||||
)
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -85,9 +81,7 @@ def get_directory_size(directory: Path) -> int:
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(
|
||||
imgs_dir: Path, timestamps: list[float], fps: int
|
||||
) -> torch.Tensor:
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
@@ -100,11 +94,7 @@ def load_original_frames(
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path,
|
||||
save_dir: Path,
|
||||
frames: torch.Tensor,
|
||||
timestamps: list[float],
|
||||
fps: int,
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
@@ -114,10 +104,7 @@ def save_decoded_frames(
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(
|
||||
imgs_dir / f"frame_{idx:06d}.png",
|
||||
save_dir / f"frame_{idx:06d}_original.png",
|
||||
)
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
@@ -129,17 +116,11 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [
|
||||
key for key in hf_dataset.features if key.startswith("observation.image")
|
||||
]
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(
|
||||
imgs_dataset,
|
||||
desc=f"saving {dataset.repo_id} first episode images",
|
||||
leave=False,
|
||||
)
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
@@ -148,9 +129,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(
|
||||
timestamps_mode: str, ep_num_images: int, fps: int
|
||||
) -> list[float]:
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
@@ -175,9 +154,7 @@ def decode_video_frames(
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
@@ -204,9 +181,7 @@ def benchmark_decoding(
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(
|
||||
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
|
||||
)
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
@@ -215,18 +190,12 @@ def benchmark_decoding(
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(
|
||||
mean_squared_error(original_frames_np[i], frames_np[i])
|
||||
)
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0
|
||||
)
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(
|
||||
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
|
||||
)
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
@@ -246,9 +215,7 @@ def benchmark_decoding(
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(
|
||||
as_completed(futures), total=num_samples, desc="samples", leave=False
|
||||
):
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
@@ -299,7 +266,7 @@ def benchmark_encoding_decoding(
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
@@ -308,13 +275,9 @@ def benchmark_encoding_decoding(
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"],
|
||||
desc="decodings (timestamps_modes)",
|
||||
leave=False,
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(
|
||||
decoding_cfg["backends"], desc="decodings (backends)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
@@ -392,23 +355,14 @@ def main(
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(
|
||||
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
|
||||
):
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path(
|
||||
"_".join(str(value) for value in encoding_cfg.values())
|
||||
)
|
||||
video_path = (
|
||||
output_dir
|
||||
/ "videos"
|
||||
/ args_path
|
||||
/ f"{repo_id.replace('/', '_')}.mp4"
|
||||
)
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
@@ -434,9 +388,7 @@ def main(
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = (
|
||||
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
|
||||
18
checkport.py
18
checkport.py
@@ -1,18 +0,0 @@
|
||||
import socket
|
||||
|
||||
|
||||
def check_port(host, port):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
s.connect((host, port))
|
||||
print(f"Connection successful to {host}:{port}!")
|
||||
except Exception as e:
|
||||
print(f"Connection failed to {host}:{port}: {e}")
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
host = "127.0.0.1" # or "localhost"
|
||||
port = 51350
|
||||
check_port(host, port)
|
||||
@@ -1,11 +0,0 @@
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[mani-skill]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
@@ -1,31 +1,25 @@
|
||||
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
|
||||
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
|
||||
|
||||
|
||||
## A. Source the parts
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install LeRobot
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
# Linux:
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
# Mac M-series:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
# Mac Intel:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
@@ -42,30 +36,23 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## C. Configure the motors
|
||||
## Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated to each arm
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
|
||||
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
|
||||
|
||||
#### a. Run the script to find ports
|
||||
|
||||
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
@@ -77,6 +64,7 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
@@ -89,20 +77,13 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update YAML file
|
||||
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
|
||||
### 2. Configure the motors
|
||||
|
||||
#### a. Set IDs for all 12 motors
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
@@ -113,7 +94,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
@@ -127,25 +108,23 @@ python lerobot/scripts/configure_motor.py \
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### b. Remove the gears of the 6 leader motors
|
||||
|
||||
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### c. Add motor horn to all 12 motors
|
||||
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## D. Assemble the arms
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## E. Calibrate
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### a. Manual calibration of follower arm
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
@@ -160,8 +139,8 @@ python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
@@ -174,7 +153,7 @@ python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## F. Teleoperate
|
||||
## Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
@@ -186,14 +165,14 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
```
|
||||
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
```
|
||||
|
||||
## G. Record a dataset
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
@@ -213,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--tags so100 tutorial \
|
||||
--warmup-time-s 5 \
|
||||
@@ -222,7 +202,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## H. Visualize a dataset
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
@@ -232,25 +212,27 @@ echo ${HF_USER}/so100_test
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
## I. Replay an episode
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/so100_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/so100_test \
|
||||
policy=act_so100_real \
|
||||
env=so100_real \
|
||||
@@ -266,16 +248,18 @@ Let's explain it:
|
||||
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_so100_test \
|
||||
--tags so100 tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
@@ -289,7 +273,7 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
|
||||
|
||||
## L. More Information
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--tags moss tutorial \
|
||||
--warmup-time-s 5 \
|
||||
@@ -211,6 +212,7 @@ echo ${HF_USER}/moss_test
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
@@ -218,9 +220,10 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--episode 0
|
||||
```
|
||||
@@ -229,7 +232,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/moss_test \
|
||||
policy=act_moss_real \
|
||||
env=moss_real \
|
||||
@@ -245,6 +248,7 @@ Let's explain it:
|
||||
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
|
||||
@@ -255,6 +259,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_moss_test \
|
||||
--tags moss tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
|
||||
|
||||
#### Generate protobuf files
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
-I lerobot/scripts/server \
|
||||
--python_out=lerobot/scripts/server \
|
||||
--grpc_python_out=lerobot/scripts/server \
|
||||
lerobot/scripts/server/hilserl.proto
|
||||
```
|
||||
@@ -3,128 +3,78 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
- Viewing a dataset's metadata and exploring its properties.
|
||||
- Loading an existing dataset from the hub or a subset of it.
|
||||
- Accessing frames by episode number.
|
||||
- Loading a dataset and accessing its properties.
|
||||
- Filtering data by episode number.
|
||||
- Converting tensor data for visualization.
|
||||
- Saving video files from dataset frames.
|
||||
- Using advanced dataset features like timestamp-based frame selection.
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [
|
||||
info.id
|
||||
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
|
||||
]
|
||||
pprint(repo_ids)
|
||||
# Let's take one for this example
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(
|
||||
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
|
||||
)
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
# You can easily load a dataset from a Hugging Face repository
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets/index for more information).
|
||||
print(dataset)
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
# And provides additional utilities for robotics and compatibility with Pytorch
|
||||
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
|
||||
|
||||
# Access frame indexes associated to first episode
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
|
||||
# them, we convert to uint8 in range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
# Finally, we save the frames to a mp4 video for visualization.
|
||||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
"observation.image": [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
@@ -134,9 +84,8 @@ dataloader = torch.utils.data.DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
||||
print(f"{batch['action'].shape=}") # (32,64,c)
|
||||
break
|
||||
|
||||
@@ -32,9 +32,7 @@ if torch.cuda.is_available():
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(
|
||||
f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU."
|
||||
)
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
|
||||
@@ -31,24 +31,7 @@ delta_timestamps = {
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
@@ -57,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
|
||||
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,17 +10,17 @@ from torchvision.transforms import ToPILImage, v2
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset_repo_id = "lerobot/aloha_static_screw_driver"
|
||||
dataset_repo_id = "lerobot/aloha_static_tape"
|
||||
|
||||
# Create a LeRobotDataset with no transformations
|
||||
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
dataset = LeRobotDataset(dataset_repo_id)
|
||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
@@ -28,20 +28,15 @@ transforms = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 1.5)),
|
||||
v2.ColorJitter(contrast=(0.5, 1.5)),
|
||||
v2.ColorJitter(hue=(-0.1, 0.1)),
|
||||
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
|
||||
]
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(
|
||||
dataset_repo_id, episodes=[0], image_transforms=transforms
|
||||
)
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][
|
||||
transformed_dataset.meta.camera_keys[0]
|
||||
]
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
|
||||
@@ -29,7 +29,7 @@ For a visual walkthrough of the assembly process, you can refer to [this video t
|
||||
|
||||
## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1
|
||||
|
||||
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
|
||||
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands.
|
||||
|
||||
Using `pip`:
|
||||
```bash
|
||||
@@ -778,6 +778,7 @@ Now run this to record 2 episodes:
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--tags tutorial \
|
||||
--warmup-time-s 5 \
|
||||
@@ -786,7 +787,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
--num-episodes 2
|
||||
```
|
||||
|
||||
This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
@@ -839,6 +840,7 @@ In the coming months, we plan to release a foundational model for robotics. We a
|
||||
You can visualize your dataset by running:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test
|
||||
```
|
||||
|
||||
@@ -856,6 +858,7 @@ To replay the first episode of the dataset you just recorded, run the following
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
@@ -868,7 +871,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/koch_test \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
@@ -915,6 +918,7 @@ env:
|
||||
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
@@ -987,6 +991,7 @@ To this end, you can use the `record` function from [`lerobot/scripts/control_ro
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/koch.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_koch_test \
|
||||
--tags tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
@@ -1005,6 +1010,7 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_koch_test
|
||||
```
|
||||
|
||||
|
||||
@@ -128,6 +128,7 @@ Record one episode:
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||
--fps 20 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/stretch_test \
|
||||
--tags stretch tutorial \
|
||||
--warmup-time-s 3 \
|
||||
@@ -145,6 +146,7 @@ Now try to replay this episode (make sure the robot's initial position is the sa
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/stretch.yaml \
|
||||
--fps 20 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/stretch_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
@@ -56,7 +56,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-overrides max_relative_target=5
|
||||
```
|
||||
|
||||
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
|
||||
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
@@ -84,6 +84,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--tags aloha tutorial \
|
||||
--warmup-time-s 5 \
|
||||
@@ -103,6 +104,7 @@ echo ${HF_USER}/aloha_test
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test
|
||||
```
|
||||
|
||||
@@ -117,6 +119,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/aloha_test \
|
||||
--episode 0
|
||||
```
|
||||
@@ -125,7 +128,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/aloha_test \
|
||||
policy=act_aloha_real \
|
||||
env=aloha_real \
|
||||
@@ -141,6 +144,7 @@ Let's explain it:
|
||||
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.
|
||||
|
||||
@@ -152,6 +156,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/aloha.yaml \
|
||||
--robot-overrides max_relative_target=null \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_aloha_test \
|
||||
--tags aloha tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
|
||||
@@ -14,10 +14,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
@@ -40,44 +37,29 @@ delta_timestamps = {
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
|
||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
PUSHT_FEATURES = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"observation.environment_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [
|
||||
"keypoints",
|
||||
],
|
||||
},
|
||||
"observation.image": {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def build_features(mode: str) -> dict:
|
||||
features = PUSHT_FEATURES
|
||||
if mode == "keypoints":
|
||||
features.pop("observation.image")
|
||||
else:
|
||||
features.pop("observation.environment_state")
|
||||
features["observation.image"]["dtype"] = mode
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def load_raw_dataset(zarr_path: Path):
|
||||
try:
|
||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
return zarr_data
|
||||
|
||||
|
||||
def calculate_coverage(zarr_data):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
block_angle = zarr_data["state"][:, 4]
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
|
||||
def calculate_success(coverage: float, success_threshold: float):
|
||||
return coverage > success_threshold
|
||||
|
||||
|
||||
def calculate_reward(coverage: float, success_threshold: float):
|
||||
return np.clip(coverage / success_threshold, 0, 1)
|
||||
|
||||
|
||||
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
|
||||
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
|
||||
|
||||
env_state = zarr_data["state"][:]
|
||||
agent_pos = env_state[:, :2]
|
||||
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
}
|
||||
|
||||
# Calculate success and reward based on the overlapping area
|
||||
# of the T-object and the T-area.
|
||||
coverage, keypoints = calculate_coverage(zarr_data)
|
||||
success = calculate_success(coverage, success_threshold=0.95)
|
||||
reward = calculate_reward(coverage, success_threshold=0.95)
|
||||
|
||||
features = build_features(mode)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=10,
|
||||
robot_type="2d pointer",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
episodes = range(len(episode_data_index["from"]))
|
||||
for ep_idx in episodes:
|
||||
from_idx = episode_data_index["from"][ep_idx]
|
||||
to_idx = episode_data_index["to"][ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
modes = ["video", "image", "keypoints"]
|
||||
# Uncomment if you want to try with a specific mode
|
||||
# modes = ["video"]
|
||||
# modes = ["image"]
|
||||
# modes = ["keypoints"]
|
||||
|
||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for mode in modes:
|
||||
if mode in ["image", "keypoints"]:
|
||||
repo_id += f"_{mode}"
|
||||
|
||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# breakpoint()
|
||||
@@ -181,12 +181,8 @@ available_real_world_datasets = [
|
||||
"lerobot/usc_cloth_sim",
|
||||
]
|
||||
|
||||
available_datasets = sorted(
|
||||
set(
|
||||
itertools.chain(
|
||||
*available_datasets_per_env.values(), available_real_world_datasets
|
||||
)
|
||||
)
|
||||
available_datasets = list(
|
||||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
@@ -228,13 +224,9 @@ available_policies_per_env = {
|
||||
"dora_aloha_real": ["act_aloha_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [
|
||||
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
|
||||
]
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
env_dataset_pairs = [
|
||||
(env, dataset)
|
||||
for env, datasets in available_datasets_per_env.items()
|
||||
for dataset in datasets
|
||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
||||
]
|
||||
env_dataset_policy_triplets = [
|
||||
(env, dataset, policy)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
|
||||
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
|
||||
- **License:** {{ license | default("[More Information Needed]", true)}}
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
{{ dataset_structure | default("[More Information Needed]", true)}}
|
||||
|
||||
## Citation
|
||||
|
||||
**BibTeX:**
|
||||
|
||||
```bibtex
|
||||
{{ citation_bibtex | default("[More Information Needed]", true)}}
|
||||
```
|
||||
@@ -19,6 +19,9 @@ from math import ceil
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image
|
||||
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
@@ -36,29 +39,23 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in dataset.features.items():
|
||||
# NOTE: skip language_instruction embedding in stats computation
|
||||
if key == "language_instruction":
|
||||
continue
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.meta.camera_keys:
|
||||
if isinstance(feats_type, (VideoFrame, Image)):
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {batch[key].shape}"
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert (
|
||||
batch[key].dtype == torch.float32
|
||||
), f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert (
|
||||
batch[key].max() <= 1
|
||||
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert (
|
||||
batch[key].min() >= 0
|
||||
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
@@ -66,7 +63,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
|
||||
return stats_patterns
|
||||
|
||||
@@ -106,11 +103,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(
|
||||
dataloader,
|
||||
total=ceil(max_num_samples / batch_size),
|
||||
desc="Compute mean, min, max",
|
||||
)
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -125,16 +118,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = (
|
||||
mean[key]
|
||||
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
)
|
||||
max[key] = torch.maximum(
|
||||
max[key], einops.reduce(batch[key], pattern, "max")
|
||||
)
|
||||
min[key] = torch.minimum(
|
||||
min[key], einops.reduce(batch[key], pattern, "min")
|
||||
)
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -143,9 +129,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(
|
||||
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
|
||||
)
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -159,9 +143,7 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = (
|
||||
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
)
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -193,51 +175,39 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.meta.stats.keys())
|
||||
data_keys.update(dataset.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[
|
||||
ds.meta.stats[data_key][stat_key]
|
||||
for ds in ls_datasets
|
||||
if data_key in ds.meta.stats
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(
|
||||
d.num_frames for d in ls_datasets if data_key in d.meta.stats
|
||||
)
|
||||
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
||||
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
if data_key in d.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(
|
||||
d.meta.stats[data_key]["std"] ** 2
|
||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
||||
)
|
||||
* (d.num_frames / total_samples)
|
||||
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
|
||||
* (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.meta.stats
|
||||
if data_key in d.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
@@ -74,25 +74,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
default_tf = OmegaConf.create(
|
||||
{
|
||||
"brightness": {"weight": 0.0, "min_max": None},
|
||||
"contrast": {"weight": 0.0, "min_max": None},
|
||||
"saturation": {"weight": 0.0, "min_max": None},
|
||||
"hue": {"weight": 0.0, "min_max": None},
|
||||
"sharpness": {"weight": 0.0, "min_max": None},
|
||||
"max_num_transforms": None,
|
||||
"random_order": False,
|
||||
"image_size": None,
|
||||
"interpolation": None,
|
||||
"image_mean": None,
|
||||
"image_std": None,
|
||||
}
|
||||
)
|
||||
cfg_tf = OmegaConf.merge(
|
||||
OmegaConf.create(default_tf), cfg.training.image_transforms
|
||||
)
|
||||
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
@@ -106,18 +88,12 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width)
|
||||
if cfg_tf.image_size
|
||||
else None,
|
||||
interpolation=cfg_tf.interpolation,
|
||||
image_mean=cfg_tf.image_mean,
|
||||
image_std=cfg_tf.image_std,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
@@ -125,6 +101,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
@@ -135,8 +112,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(
|
||||
stats, dtype=torch.float32
|
||||
)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
dataset = kwargs.get("dataset")
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
while True:
|
||||
item = queue.get()
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
queue.task_done()
|
||||
|
||||
|
||||
def worker_process(queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it creates a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
|
||||
their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, num_processes: int = 0, num_threads: int = 1):
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = num_threads
|
||||
self.queue = None
|
||||
self.threads = []
|
||||
self.processes = []
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError(
|
||||
"Number of threads and processes must be greater than zero."
|
||||
)
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
self.queue = queue.Queue()
|
||||
for _ in range(self.num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
self.threads.append(t)
|
||||
else:
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_process, args=(self.queue, self.num_threads)
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
def stop(self):
|
||||
if self._stopped:
|
||||
return
|
||||
|
||||
if self.num_processes == 0:
|
||||
for _ in self.threads:
|
||||
self.queue.put(None)
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
else:
|
||||
num_nones = self.num_processes * self.num_threads
|
||||
for _ in range(num_nones):
|
||||
self.queue.put(None)
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
|
||||
self._stopped = True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -131,9 +131,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(
|
||||
self, data_spec: dict[str, Any], buffer_capacity: int
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
@@ -156,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
||||
"dtype": np.dtype("?"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {
|
||||
"dtype": np.dtype("float64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {
|
||||
"dtype": v["dtype"],
|
||||
"shape": (buffer_capacity, *v["shape"]),
|
||||
}
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
@@ -207,10 +187,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
next_index - 1
|
||||
]
|
||||
if self.num_samples > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
@@ -245,19 +223,15 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
|
||||
]
|
||||
)
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
def num_samples(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
return self.num_samples
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
@@ -287,9 +261,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
|
||||
episode_data_indices
|
||||
]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
@@ -306,8 +278,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0])
|
||||
| (episode_timestamps[-1] < query_ts[is_pad])
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
@@ -322,9 +293,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(
|
||||
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
|
||||
)
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
@@ -355,19 +324,13 @@ def compute_sampler_weights(
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (
|
||||
online_dataset is None or len(online_dataset) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"At least one of `offline_dataset` or `online_dataset` should be contain data."
|
||||
)
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = (
|
||||
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
|
||||
468
lerobot/common/datasets/populate_dataset.py
Normal file
468
lerobot/common/datasets/populate_dataset.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""Functions to create an empty dataset, and populate it with frames."""
|
||||
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
|
||||
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.utils.utils import log_say
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Asynchrounous saving of images on disk
|
||||
########################################################################################
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
# TODO(aliberts): Allow to pass custom exceptions
|
||||
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
image_writer = kwargs.get("dataset", {}).get("image_writer")
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
|
||||
|
||||
def loop_to_save_images_in_threads(image_queue, num_threads):
|
||||
if num_threads < 1:
|
||||
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
# Blocks until a frame is available
|
||||
frame_data = image_queue.get()
|
||||
|
||||
# As usually done, exit loop when receiving None to stop the worker
|
||||
if frame_data is None:
|
||||
break
|
||||
|
||||
image, key, frame_index, episode_index, videos_dir = frame_data
|
||||
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
|
||||
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
|
||||
if num_processes < 1:
|
||||
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
|
||||
|
||||
if num_threads_per_process < 1:
|
||||
raise NotImplementedError(
|
||||
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
|
||||
)
|
||||
|
||||
processes = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(
|
||||
target=loop_to_save_images_in_threads,
|
||||
args=(image_queue, num_threads_per_process),
|
||||
)
|
||||
process.start()
|
||||
processes.append(process)
|
||||
return processes
|
||||
|
||||
|
||||
def stop_processes(processes, queue, timeout):
|
||||
# Send None to each process to signal them to stop
|
||||
for _ in processes:
|
||||
queue.put(None)
|
||||
|
||||
# Wait maximum 20 seconds for all processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
# If not terminated after 20 seconds, force termination
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
# Close the queue, no more items can be put in the queue
|
||||
queue.close()
|
||||
|
||||
# Ensure all background queue threads have finished
|
||||
queue.join_thread()
|
||||
|
||||
|
||||
def start_image_writer(num_processes, num_threads):
|
||||
"""This function abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
|
||||
where each subprocess starts their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
image_writer = {}
|
||||
|
||||
if num_processes == 0:
|
||||
futures = []
|
||||
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
|
||||
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
|
||||
else:
|
||||
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
|
||||
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
|
||||
image_queue = multiprocessing.Queue()
|
||||
processes_pool = start_image_writer_processes(
|
||||
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
|
||||
)
|
||||
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
|
||||
|
||||
return image_writer
|
||||
|
||||
|
||||
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
|
||||
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
|
||||
called image writer which contains either a pool of processes or a pool of threads.
|
||||
"""
|
||||
if "threads_pool" in image_writer:
|
||||
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
|
||||
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
else:
|
||||
image_queue = image_writer["image_queue"]
|
||||
image_queue.put((image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
|
||||
def stop_image_writer(image_writer, timeout):
|
||||
if "threads_pool" in image_writer:
|
||||
futures = image_writer["futures"]
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures, timeout=timeout)
|
||||
progress_bar.update(len(futures))
|
||||
else:
|
||||
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
|
||||
stop_processes(processes_pool, image_queue, timeout=timeout)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Functions to initialize, resume and populate a dataset
|
||||
########################################################################################
|
||||
|
||||
|
||||
def init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
video,
|
||||
write_images,
|
||||
num_image_writer_processes,
|
||||
num_image_writer_threads,
|
||||
):
|
||||
local_dir = Path(root) / repo_id
|
||||
if local_dir.exists() and force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
episodes_dir = local_dir / "episodes"
|
||||
episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Logic to resume data recording
|
||||
rec_info_path = episodes_dir / "data_recording_info.json"
|
||||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
num_episodes = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
num_episodes = 0
|
||||
|
||||
dataset = {
|
||||
"repo_id": repo_id,
|
||||
"local_dir": local_dir,
|
||||
"videos_dir": videos_dir,
|
||||
"episodes_dir": episodes_dir,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"rec_info_path": rec_info_path,
|
||||
"num_episodes": num_episodes,
|
||||
}
|
||||
|
||||
if write_images:
|
||||
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
|
||||
# which is critical to control a robot and record data at a high frame rate.
|
||||
image_writer = start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads,
|
||||
)
|
||||
dataset["image_writer"] = image_writer
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def add_frame(dataset, observation, action):
|
||||
if "current_episode" not in dataset:
|
||||
# initialize episode dictionary
|
||||
ep_dict = {}
|
||||
for key in observation:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
|
||||
ep_dict["episode_index"] = []
|
||||
ep_dict["frame_index"] = []
|
||||
ep_dict["timestamp"] = []
|
||||
ep_dict["next.done"] = []
|
||||
|
||||
dataset["current_episode"] = ep_dict
|
||||
dataset["current_frame_index"] = 0
|
||||
|
||||
ep_dict = dataset["current_episode"]
|
||||
episode_index = dataset["num_episodes"]
|
||||
frame_index = dataset["current_frame_index"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
ep_dict["episode_index"].append(episode_index)
|
||||
ep_dict["frame_index"].append(frame_index)
|
||||
ep_dict["timestamp"].append(frame_index / fps)
|
||||
ep_dict["next.done"].append(False)
|
||||
|
||||
img_keys = [key for key in observation if "image" in key]
|
||||
non_img_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in non_img_keys:
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
# Save actions
|
||||
for key in action:
|
||||
ep_dict[key].append(action[key])
|
||||
|
||||
if "image_writer" not in dataset:
|
||||
dataset["current_frame_index"] += 1
|
||||
return
|
||||
|
||||
# Save images
|
||||
image_writer = dataset["image_writer"]
|
||||
for key in img_keys:
|
||||
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
async_save_image(
|
||||
image_writer,
|
||||
image=observation[key],
|
||||
key=key,
|
||||
frame_index=frame_index,
|
||||
episode_index=episode_index,
|
||||
videos_dir=str(videos_dir),
|
||||
)
|
||||
|
||||
if video:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
|
||||
else:
|
||||
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
|
||||
|
||||
ep_dict[key].append(frame_info)
|
||||
|
||||
dataset["current_frame_index"] += 1
|
||||
|
||||
|
||||
def delete_current_episode(dataset):
|
||||
del dataset["current_episode"]
|
||||
del dataset["current_frame_index"]
|
||||
|
||||
# delete temporary images
|
||||
episode_index = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def save_current_episode(dataset):
|
||||
episode_index = dataset["num_episodes"]
|
||||
ep_dict = dataset["current_episode"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
rec_info_path = dataset["rec_info_path"]
|
||||
|
||||
ep_dict["next.done"][-1] = True
|
||||
|
||||
for key in ep_dict:
|
||||
if "observation" in key and "image" not in key:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
ep_dict["action"] = torch.stack(ep_dict["action"])
|
||||
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
|
||||
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
|
||||
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
|
||||
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
torch.save(ep_dict, ep_path)
|
||||
|
||||
rec_info = {
|
||||
"last_episode_index": episode_index,
|
||||
}
|
||||
with open(rec_info_path, "w") as f:
|
||||
json.dump(rec_info, f)
|
||||
|
||||
# force re-initialization of episode dictionnary during add_frame
|
||||
del dataset["current_episode"]
|
||||
|
||||
dataset["num_episodes"] += 1
|
||||
|
||||
|
||||
def encode_videos(dataset, image_keys, play_sounds):
|
||||
log_say("Encoding videos", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
local_dir = dataset["local_dir"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
# key = f"observation.images.{name}"
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||
log_say("Consolidate episodes", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
repo_id = dataset["repo_id"]
|
||||
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
if video:
|
||||
image_keys = [key for key in data_dict if "image" in key]
|
||||
encode_videos(dataset, image_keys, play_sounds)
|
||||
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def save_lerobot_dataset_on_disk(lerobot_dataset):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
info = lerobot_dataset.info
|
||||
stats = lerobot_dataset.stats
|
||||
episode_data_index = lerobot_dataset.episode_data_index
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
|
||||
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
videos_dir = lerobot_dataset.videos_dir
|
||||
repo_id = lerobot_dataset.repo_id
|
||||
video = lerobot_dataset.video
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
if not (local_dir / "train").exists():
|
||||
raise ValueError(
|
||||
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
|
||||
)
|
||||
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
|
||||
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
||||
if "image_writer" in dataset:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
image_writer = dataset["image_writer"]
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
||||
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
||||
|
||||
if run_compute_stats:
|
||||
log_say("Computing dataset statistics", play_sounds)
|
||||
lerobot_dataset.stats = compute_stats(lerobot_dataset)
|
||||
else:
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
lerobot_dataset.stats = {}
|
||||
|
||||
save_lerobot_dataset_on_disk(lerobot_dataset)
|
||||
|
||||
if push_to_hub:
|
||||
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
|
||||
|
||||
return lerobot_dataset
|
||||
@@ -37,16 +37,10 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
|
||||
assert c > 0
|
||||
|
||||
|
||||
def rechunk_recompress_array(
|
||||
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
|
||||
):
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
|
||||
old_arr = group[name]
|
||||
if chunks is None:
|
||||
chunks = (
|
||||
(chunk_length,) + old_arr.chunks[1:]
|
||||
if chunk_length is not None
|
||||
else old_arr.chunks
|
||||
)
|
||||
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
|
||||
check_chunks_compatible(chunks, old_arr.shape)
|
||||
|
||||
if compressor is None:
|
||||
@@ -88,18 +82,13 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
|
||||
for i in range(len(shape) - 1):
|
||||
this_chunk_bytes = itemsize * np.prod(rshape[:i])
|
||||
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
|
||||
if (
|
||||
this_chunk_bytes <= target_chunk_bytes
|
||||
and next_chunk_bytes > target_chunk_bytes
|
||||
):
|
||||
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
|
||||
split_idx = i
|
||||
|
||||
rchunks = rshape[:split_idx]
|
||||
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
|
||||
this_max_chunk_length = rshape[split_idx]
|
||||
next_chunk_length = min(
|
||||
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
|
||||
)
|
||||
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
|
||||
rchunks.append(next_chunk_length)
|
||||
len_diff = len(shape) - len(rchunks)
|
||||
rchunks.extend([1] * len_diff)
|
||||
@@ -135,13 +124,7 @@ class ReplayBuffer:
|
||||
root.require_group("data", overwrite=False)
|
||||
meta = root.require_group("meta", overwrite=False)
|
||||
if "episode_ends" not in meta:
|
||||
meta.zeros(
|
||||
"episode_ends",
|
||||
shape=(0,),
|
||||
dtype=np.int64,
|
||||
compressor=None,
|
||||
overwrite=False,
|
||||
)
|
||||
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
@@ -210,11 +193,7 @@ class ReplayBuffer:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
dest_path="/meta",
|
||||
if_exists=if_exists,
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
if keys is None:
|
||||
@@ -222,9 +201,7 @@ class ReplayBuffer:
|
||||
for key in keys:
|
||||
value = src_root["data"][key]
|
||||
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
@@ -309,17 +286,13 @@ class ReplayBuffer:
|
||||
meta_group = root.create_group("meta", overwrite=True)
|
||||
# save meta, no chunking
|
||||
for key, value in self.root["meta"].items():
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape
|
||||
)
|
||||
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
|
||||
|
||||
# save data, chunk
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
for key, value in self.root["data"].items():
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if isinstance(value, zarr.Array):
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
@@ -366,19 +339,13 @@ class ReplayBuffer:
|
||||
@staticmethod
|
||||
def resolve_compressor(compressor="default"):
|
||||
if compressor == "default":
|
||||
compressor = numcodecs.Blosc(
|
||||
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
|
||||
)
|
||||
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
|
||||
elif compressor == "disk":
|
||||
compressor = numcodecs.Blosc(
|
||||
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
|
||||
)
|
||||
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
|
||||
return compressor
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_compressor(
|
||||
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
|
||||
):
|
||||
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
|
||||
# allows compressor to be explicitly set to None
|
||||
cpr = "nil"
|
||||
if isinstance(compressors, dict):
|
||||
@@ -437,11 +404,7 @@ class ReplayBuffer:
|
||||
if self.backend == "zarr":
|
||||
for key, value in np_data.items():
|
||||
_ = meta_group.array(
|
||||
name=key,
|
||||
data=value,
|
||||
shape=value.shape,
|
||||
chunks=value.shape,
|
||||
overwrite=True,
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
|
||||
)
|
||||
else:
|
||||
meta_group.update(np_data)
|
||||
@@ -551,18 +514,10 @@ class ReplayBuffer:
|
||||
# create array
|
||||
if key not in self.data:
|
||||
if is_zarr:
|
||||
cks = self._resolve_array_chunks(
|
||||
chunks=chunks, key=key, array=value
|
||||
)
|
||||
cpr = self._resolve_array_compressor(
|
||||
compressors=compressors, key=key, array=value
|
||||
)
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
arr = self.data.zeros(
|
||||
name=key,
|
||||
shape=new_shape,
|
||||
chunks=cks,
|
||||
dtype=value.dtype,
|
||||
compressor=cpr,
|
||||
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
|
||||
)
|
||||
else:
|
||||
# copy data to prevent modify
|
||||
@@ -589,9 +544,7 @@ class ReplayBuffer:
|
||||
|
||||
# rechunk
|
||||
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
|
||||
rechunk_recompress_array(
|
||||
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
|
||||
)
|
||||
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
|
||||
|
||||
def drop_episode(self):
|
||||
is_zarr = self.backend == "zarr"
|
||||
|
||||
@@ -38,9 +38,7 @@ import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
|
||||
AVAILABLE_RAW_REPO_IDS,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
@@ -75,9 +73,7 @@ def encode_datasets(
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = (
|
||||
local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
)
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
|
||||
@@ -133,9 +133,7 @@ class Jpeg2k(Codec):
|
||||
)
|
||||
|
||||
def decode(self, buf, out=None):
|
||||
return imagecodecs.jpeg2k_decode(
|
||||
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
|
||||
)
|
||||
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||
|
||||
|
||||
class JpegXl(Codec):
|
||||
|
||||
@@ -30,12 +30,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -44,9 +44,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
def get_cameras(hdf5_data):
|
||||
# ignore depth channel, not currently handled
|
||||
# TODO(rcadene): add depth
|
||||
rgb_cameras = [
|
||||
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
|
||||
] # noqa: SIM118
|
||||
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
return rgb_cameras
|
||||
|
||||
|
||||
@@ -75,9 +73,7 @@ def check_format(raw_dir) -> bool:
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -138,17 +134,14 @@ def load_from_raw(
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -188,18 +181,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -24,11 +24,8 @@ from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
)
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
|
||||
@@ -26,10 +26,8 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
@@ -44,19 +42,11 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
):
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(
|
||||
f"Missing reference files for camera, starting with in '{raw_dir}'"
|
||||
)
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
@@ -117,9 +107,7 @@ def load_from_raw(
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
|
||||
lambda x: x - x.iloc[0]
|
||||
)
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
@@ -132,9 +120,7 @@ def load_from_raw(
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
if ep_ids != expected_ep_ids:
|
||||
raise ValueError(
|
||||
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
|
||||
)
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -166,9 +152,7 @@ def load_from_raw(
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack(
|
||||
[torch.from_numpy(x.copy()) for x in df[key].values]
|
||||
)
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
@@ -186,18 +170,15 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
639
lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml
Normal file
639
lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml
Normal file
@@ -0,0 +1,639 @@
|
||||
OPENX_DATASET_CONFIGS:
|
||||
fractal20220817_data:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 3
|
||||
|
||||
kuka:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- clip_function_input/base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 10
|
||||
|
||||
bridge_openx:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- EEF_state
|
||||
- gripper_state
|
||||
fps: 5
|
||||
|
||||
taco_play:
|
||||
image_obs_keys:
|
||||
- rgb_static
|
||||
- rgb_gripper
|
||||
depth_obs_keys:
|
||||
- depth_static
|
||||
- depth_gripper
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 15
|
||||
|
||||
jaco_play:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_wrist
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 10
|
||||
|
||||
berkeley_cable_routing:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- top_image
|
||||
- wrist45_image
|
||||
- wrist225_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
roboturk:
|
||||
image_obs_keys:
|
||||
- front_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
nyu_door_opening_surprising_effectiveness:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
||||
viola:
|
||||
image_obs_keys:
|
||||
- agentview_rgb
|
||||
- eye_in_hand_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_states
|
||||
- gripper_states
|
||||
fps: 20
|
||||
|
||||
berkeley_autolab_ur5:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- image_with_depth
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
toto:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 30
|
||||
|
||||
language_table:
|
||||
image_obs_keys:
|
||||
- rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- effector_translation
|
||||
fps: 10
|
||||
|
||||
columbia_cairlab_pusht_real:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- depth_image
|
||||
state_obs_keys:
|
||||
- ee_position
|
||||
- ee_orientation
|
||||
fps: 20
|
||||
|
||||
nyu_rot_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
io_ai_tech:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_fisheye
|
||||
- image_left_side
|
||||
- image_right_side
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 3
|
||||
|
||||
stanford_hydra_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
austin_buds_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
nyu_franka_play_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_additional_view
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- depth_additional_view
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
fps: 3
|
||||
|
||||
maniskill_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- wrist_depth
|
||||
state_obs_keys:
|
||||
- tcp_pose
|
||||
- gripper_state
|
||||
fps: 20
|
||||
|
||||
furniture_bench_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
cmu_franka_exploration_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- highres_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
ucsd_kitchen_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
fps: 2
|
||||
|
||||
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
spoc:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_manipulation
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
||||
austin_sailor_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
austin_sirius_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
bc_z:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- present/xyz
|
||||
- present/axis_angle
|
||||
- present/sensed_close
|
||||
fps: 10
|
||||
|
||||
utokyo_pr2_opening_fridge_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
utokyo_xarm_pick_and_place_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image2
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- end_effector_pose
|
||||
fps: 10
|
||||
|
||||
utokyo_xarm_bimanual_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- pose_r
|
||||
fps: 10
|
||||
|
||||
robo_net:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image1
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 1
|
||||
|
||||
robo_set:
|
||||
image_obs_keys:
|
||||
- image_left
|
||||
- image_right
|
||||
- image_wrist
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- state_velocity
|
||||
fps: 5
|
||||
|
||||
berkeley_mvp_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- gripper
|
||||
- pose
|
||||
- joint_pos
|
||||
fps: 5
|
||||
|
||||
berkeley_rpt_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_pos
|
||||
- gripper
|
||||
fps: 30
|
||||
|
||||
kaist_nonprehensile_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
stanford_mask_vit_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
|
||||
tokyo_u_lsmo_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
dlr_sara_pour_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
dlr_sara_grid_clamp_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
dlr_edan_shared_control_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
asu_table_top_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 12.5
|
||||
|
||||
stanford_robocook_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image_1
|
||||
- image_2
|
||||
depth_obs_keys:
|
||||
- depth_1
|
||||
- depth_2
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 5
|
||||
|
||||
imperialcollege_sawyer_wrist_cam:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
iamlab_cmu_pickup_insert_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
- gripper_state
|
||||
fps: 20
|
||||
|
||||
uiuc_d3field:
|
||||
image_obs_keys:
|
||||
- image_1
|
||||
- image_2
|
||||
depth_obs_keys:
|
||||
- depth_1
|
||||
- depth_2
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 1
|
||||
|
||||
utaustin_mutex:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
berkeley_fanuc_manipulation:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
cmu_playing_with_food:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- finger_vision_1
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
cmu_play_fusion:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
cmu_stretch:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
berkeley_gnm_recon:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 3
|
||||
|
||||
berkeley_gnm_cory_hall:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 5
|
||||
|
||||
berkeley_gnm_sac_son:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
- position
|
||||
- yaw
|
||||
fps: 10
|
||||
|
||||
droid:
|
||||
image_obs_keys:
|
||||
- exterior_image_1_left
|
||||
- exterior_image_2_left
|
||||
- wrist_image_left
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 15
|
||||
|
||||
droid_100:
|
||||
image_obs_keys:
|
||||
- exterior_image_1_left
|
||||
- exterior_image_2_left
|
||||
- wrist_image_left
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 15
|
||||
|
||||
fmb:
|
||||
image_obs_keys:
|
||||
- image_side_1
|
||||
- image_side_2
|
||||
- image_wrist_1
|
||||
- image_wrist_2
|
||||
depth_obs_keys:
|
||||
- image_side_1_depth
|
||||
- image_side_2_depth
|
||||
- image_wrist_1_depth
|
||||
- image_wrist_2_depth
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 10
|
||||
|
||||
dobbe:
|
||||
image_obs_keys:
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- proprio
|
||||
fps: 3.75
|
||||
|
||||
usc_cloth_sim_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
plex_robosuite:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
conq_hose_manipulation:
|
||||
image_obs_keys:
|
||||
- frontleft_fisheye_image
|
||||
- frontright_fisheye_image
|
||||
- hand_color_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 30
|
||||
106
lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py
Normal file
106
lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the Licens e.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py
|
||||
|
||||
data_utils.py
|
||||
|
||||
Additional utils for data processing.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts gripper actions from continuous to binary values (0 and 1).
|
||||
|
||||
We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
|
||||
transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
|
||||
values based on the state that is reached _after_ those intermediate values.
|
||||
|
||||
In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
|
||||
chunk of intermediate values as the last action in the trajectory.
|
||||
|
||||
The `scan_fn` implements the following logic:
|
||||
new_actions = np.empty_like(actions)
|
||||
carry = actions[-1]
|
||||
for i in reversed(range(actions.shape[0])):
|
||||
if in_between_mask[i]:
|
||||
carry = carry
|
||||
else:
|
||||
carry = float(open_mask[i])
|
||||
new_actions[i] = carry
|
||||
"""
|
||||
open_mask, closed_mask = actions > 0.95, actions < 0.05
|
||||
in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
|
||||
is_open_float = tf.cast(open_mask, tf.float32)
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])
|
||||
|
||||
return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)
|
||||
|
||||
|
||||
def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
return 1 - actions
|
||||
|
||||
|
||||
def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).
|
||||
|
||||
Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
|
||||
"""
|
||||
# Note =>> -1 for closing, 1 for opening, 0 for no change
|
||||
opening_mask, closing_mask = actions < -0.1, actions > 0.1
|
||||
thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))
|
||||
|
||||
def scan_fn(carry, i):
|
||||
return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])
|
||||
|
||||
# If no relative grasp, assumes open for whole trajectory
|
||||
start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
|
||||
start = tf.cond(start == 0, lambda: 1, lambda: start)
|
||||
|
||||
# Note =>> -1 for closed, 1 for open
|
||||
new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
|
||||
new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5
|
||||
|
||||
return new_actions
|
||||
|
||||
|
||||
# === Bridge-V2 =>> Dataset-Specific Transform ===
|
||||
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
|
||||
movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
|
||||
traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
|
||||
traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)
|
||||
|
||||
return traj_truncated
|
||||
|
||||
|
||||
# === RLDS Dataset Initialization Utilities ===
|
||||
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
|
||||
print("\n######################################################################################")
|
||||
print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
|
||||
for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
|
||||
pad = 80 - len(dataset_kwargs["name"])
|
||||
print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
|
||||
print("######################################################################################\n")
|
||||
200
lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py
Normal file
200
lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py
Normal file
@@ -0,0 +1,200 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
|
||||
Episode transforms for DROID dataset.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import tensorflow_graphics.geometry.transformation as tfg
|
||||
|
||||
|
||||
def rmat_to_euler(rot_mat):
|
||||
return tfg.euler.from_rotation_matrix(rot_mat)
|
||||
|
||||
|
||||
def euler_to_rmat(euler):
|
||||
return tfg.rotation_matrix_3d.from_euler(euler)
|
||||
|
||||
|
||||
def invert_rmat(rot_mat):
|
||||
return tfg.rotation_matrix_3d.inverse(rot_mat)
|
||||
|
||||
|
||||
def rotmat_to_rot6d(mat):
|
||||
"""
|
||||
Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
|
||||
Args:
|
||||
mat: rotation matrix
|
||||
|
||||
Returns: 6d vector (first two rows of rotation matrix)
|
||||
|
||||
"""
|
||||
r6 = mat[..., :2, :]
|
||||
r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
|
||||
r6_flat = tf.concat([r6_0, r6_1], axis=-1)
|
||||
return r6_flat
|
||||
|
||||
|
||||
def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
|
||||
"""
|
||||
Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
|
||||
Args:
|
||||
velocity: 6d velocity action (3 x translation, 3 x rotation)
|
||||
wrist_in_robot_frame: 6d pose of the end-effector in robot base frame
|
||||
|
||||
Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)
|
||||
|
||||
"""
|
||||
r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
|
||||
r_frame_inv = invert_rmat(r_frame)
|
||||
|
||||
# world to wrist: dT_pi = R^-1 dT_rbt
|
||||
vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]
|
||||
|
||||
# world to wrist: dR_pi = R^-1 dR_rbt R
|
||||
dr_ = euler_to_rmat(velocity[:, 3:6])
|
||||
dr_ = r_frame_inv @ (dr_ @ r_frame)
|
||||
dr_r6 = rotmat_to_rot6d(dr_)
|
||||
return tf.concat([vel_t, dr_r6], axis=-1)
|
||||
|
||||
|
||||
def rand_swap_exterior_images(img1, img2):
|
||||
"""
|
||||
Randomly swaps the two exterior images (for training with single exterior input).
|
||||
"""
|
||||
return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))
|
||||
|
||||
|
||||
def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||
"""
|
||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
dt,
|
||||
dr_,
|
||||
1 - trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
|
||||
rand_swap_exterior_images(
|
||||
trajectory["observation"]["exterior_image_1_left"],
|
||||
trajectory["observation"]["exterior_image_2_left"],
|
||||
)
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *wrist* frame of the robot.
|
||||
"""
|
||||
wrist_act = velocity_act_to_wrist_frame(
|
||||
trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
|
||||
)
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
wrist_act,
|
||||
trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
|
||||
rand_swap_exterior_images(
|
||||
trajectory["observation"]["exterior_image_1_left"],
|
||||
trajectory["observation"]["exterior_image_2_left"],
|
||||
)
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
DROID dataset transformation for actions expressed in *base* frame of the robot.
|
||||
"""
|
||||
dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
|
||||
dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
dt,
|
||||
dr_,
|
||||
1 - trajectory["action_dict"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["cartesian_position"],
|
||||
trajectory["observation"]["gripper_position"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def zero_action_filter(traj: Dict) -> bool:
|
||||
"""
|
||||
Filters transitions whose actions are all-0 (only relative actions, no gripper action).
|
||||
Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
|
||||
"""
|
||||
droid_q01 = tf.convert_to_tensor(
|
||||
[
|
||||
-0.7776297926902771,
|
||||
-0.5803514122962952,
|
||||
-0.5795090794563293,
|
||||
-0.6464047729969025,
|
||||
-0.7041108310222626,
|
||||
-0.8895104378461838,
|
||||
]
|
||||
)
|
||||
droid_q99 = tf.convert_to_tensor(
|
||||
[
|
||||
0.7597932070493698,
|
||||
0.5726242214441299,
|
||||
0.7351000607013702,
|
||||
0.6705610305070877,
|
||||
0.6464948207139969,
|
||||
0.8897542208433151,
|
||||
]
|
||||
)
|
||||
droid_norm_0_act = (
|
||||
2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
|
||||
)
|
||||
|
||||
return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)
|
||||
859
lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py
Normal file
859
lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py
Normal file
@@ -0,0 +1,859 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
NOTE(YL): Adapted from:
|
||||
OpenVLA: https://github.com/openvla/openvla
|
||||
Octo: https://github.com/octo-models/octo
|
||||
|
||||
transforms.py
|
||||
|
||||
Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.
|
||||
|
||||
Transforms adopt the following structure:
|
||||
Input: Dictionary of *batched* features (i.e., has leading time dimension)
|
||||
Output: Dictionary `step` =>> {
|
||||
"observation": {
|
||||
<image_keys, depth_image_keys>
|
||||
State (in chosen state representation)
|
||||
},
|
||||
"action": Action (in chosen action representation),
|
||||
"language_instruction": str
|
||||
}
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
|
||||
binarize_gripper_actions,
|
||||
invert_gripper_actions,
|
||||
rel2abs_gripper_actions,
|
||||
relabel_bridge_actions,
|
||||
)
|
||||
|
||||
|
||||
def droid_baseact_transform_fn():
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform
|
||||
|
||||
return droid_baseact_transform
|
||||
|
||||
|
||||
def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies to version of Bridge V2 in Open X-Embodiment mixture.
|
||||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key in ["observation", "action"]:
|
||||
for key2 in trajectory[key]:
|
||||
trajectory[key][key2] = trajectory[key][key2][1:]
|
||||
else:
|
||||
trajectory[key] = trajectory[key][1:]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
trajectory = relabel_bridge_actions(trajectory)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies to original version of Bridge V2 from the official project website.
|
||||
|
||||
Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
|
||||
"""
|
||||
for key in trajectory:
|
||||
if key == "traj_metadata":
|
||||
continue
|
||||
elif key == "observation":
|
||||
for key2 in trajectory[key]:
|
||||
trajectory[key][key2] = trajectory[key][key2][1:]
|
||||
else:
|
||||
trajectory[key] = trajectory[key][1:]
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
[
|
||||
trajectory["action"][:, :6],
|
||||
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
trajectory = relabel_bridge_actions(trajectory)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
[
|
||||
trajectory["action"][:, :6],
|
||||
binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
# decode compressed state
|
||||
eef_value = tf.io.decode_compressed(
|
||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
|
||||
compression_type="ZLIB",
|
||||
)
|
||||
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
||||
gripper_value = tf.io.decode_compressed(
|
||||
trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
|
||||
)
|
||||
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
||||
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
|
||||
trajectory["action"] = trajectory["action"]["rel_actions_world"]
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
|
||||
trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
|
||||
:, -1:
|
||||
]
|
||||
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
tf.zeros_like(trajectory["action"]["world_vector"]),
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert absolute gripper action, +1 = open, 0 = close
|
||||
gripper_action = invert_gripper_actions(
|
||||
tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
|
||||
)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# make gripper action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
|
||||
gripper_action = tf.clip_by_value(gripper_action, 0, 1)
|
||||
gripper_action = invert_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]
|
||||
|
||||
# make gripper action absolute action, +1 = open, 0 = close
|
||||
gripper_action = trajectory["action"]["gripper_closedness_action"]
|
||||
gripper_action = rel2abs_gripper_actions(gripper_action)
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
gripper_action[:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# default to "open" gripper
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.ones_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# decode language instruction
|
||||
instruction_bytes = trajectory["observation"]["instruction"]
|
||||
instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
|
||||
# Remove trailing padding --> convert RaggedTensor to regular Tensor.
|
||||
trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
|
||||
:, 0
|
||||
]
|
||||
return trajectory
|
||||
|
||||
|
||||
def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["world_vector"],
|
||||
trajectory["action"]["rotation_delta"],
|
||||
trajectory["action"]["gripper_closedness_action"][:, None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tf.zeros_like(trajectory["action"][:, :3]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :7]
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(trajectory["action"][:, -1:]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :3],
|
||||
trajectory["observation"]["state"][:, 7:10],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
||||
return trajectory
|
||||
|
||||
|
||||
def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
|
||||
trajectory["observation"]["depth_additional_view"] = tf.cast(
|
||||
trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
|
||||
)
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]
|
||||
|
||||
# clip gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, -8:-2],
|
||||
tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
|
||||
return trajectory
|
||||
|
||||
|
||||
def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["observation"]["state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :7],
|
||||
trajectory["observation"]["state"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tf.zeros_like(trajectory["action"][:, :3]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"]["future/xyz_residual"][:, :3],
|
||||
trajectory["action"]["future/axis_angle_residual"][:, :3],
|
||||
invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., -7:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :4],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :2]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :4],
|
||||
tf.zeros_like(trajectory["action"][:, :2]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
trajectory["observation"]["state"] = tf.concat((
|
||||
tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
|
||||
trajectory["observation"]["pose"],
|
||||
trajectory["observation"]["joint_pos"],),
|
||||
axis=-1,)
|
||||
"""
|
||||
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
|
||||
return trajectory
|
||||
|
||||
|
||||
def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["end_effector_pose"][:, :4],
|
||||
tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :4],
|
||||
tf.zeros_like(trajectory["action"][:, :2]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
|
||||
return trajectory
|
||||
|
||||
|
||||
def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# invert gripper action, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(trajectory["action"][:, -1:]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
return trajectory
|
||||
|
||||
|
||||
def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
trajectory["action"][:, 7:8],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
|
||||
|
||||
# invert gripper action + clip, +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :6],
|
||||
invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]
|
||||
|
||||
# dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
invert_gripper_actions(trajectory["observation"]["gripper_state"]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
import tensorflow_graphics.geometry.transformation as tft
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
|
||||
trajectory["action"][:, -1:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :3],
|
||||
trajectory["action"][:, -4:],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["eef_state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["state"][:, :3],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
|
||||
trajectory["action"] = trajectory["action"][..., :-1]
|
||||
return trajectory
|
||||
|
||||
|
||||
def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
trajectory["observation"]["state"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["position"],
|
||||
tf.zeros_like(trajectory["observation"]["state"][:, :3]),
|
||||
trajectory["observation"]["yaw"],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"],
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"]),
|
||||
tf.zeros_like(trajectory["action"][:, :1]),
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# every input feature is batched, ie has leading batch dimension
|
||||
trajectory["observation"]["proprio"] = tf.concat(
|
||||
(
|
||||
trajectory["observation"]["eef_pose"],
|
||||
trajectory["observation"]["state_gripper_pose"][..., None],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# every input feature is batched, ie has leading batch dimension
|
||||
trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
|
||||
return trajectory
|
||||
|
||||
|
||||
def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# gripper action is in -1...1 --> clip to 0...1, flip
|
||||
gripper_action = trajectory["action"][:, -1:]
|
||||
gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))
|
||||
|
||||
trajectory["action"] = tf.concat(
|
||||
(
|
||||
trajectory["action"][:, :7],
|
||||
gripper_action,
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
return trajectory
|
||||
|
||||
|
||||
def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return trajectory
|
||||
|
||||
|
||||
# === Registry ===
|
||||
OPENX_STANDARDIZATION_TRANSFORMS = {
|
||||
"bridge_openx": bridge_openx_dataset_transform,
|
||||
"bridge_orig": bridge_orig_dataset_transform,
|
||||
"bridge_dataset": bridge_orig_dataset_transform,
|
||||
"ppgm": ppgm_dataset_transform,
|
||||
"ppgm_static": ppgm_dataset_transform,
|
||||
"ppgm_wrist": ppgm_dataset_transform,
|
||||
"fractal20220817_data": rt1_dataset_transform,
|
||||
"kuka": kuka_dataset_transform,
|
||||
"taco_play": taco_play_dataset_transform,
|
||||
"jaco_play": jaco_play_dataset_transform,
|
||||
"berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
|
||||
"roboturk": roboturk_dataset_transform,
|
||||
"nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
|
||||
"viola": viola_dataset_transform,
|
||||
"berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
|
||||
"toto": toto_dataset_transform,
|
||||
"language_table": language_table_dataset_transform,
|
||||
"columbia_cairlab_pusht_real": pusht_dataset_transform,
|
||||
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
|
||||
"nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
|
||||
"stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
|
||||
"austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
|
||||
"nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
|
||||
"maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
|
||||
"furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
|
||||
"cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
|
||||
"ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
|
||||
"ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
|
||||
"austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
|
||||
"austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
|
||||
"bc_z": bc_z_dataset_transform,
|
||||
"utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
|
||||
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
|
||||
"utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
|
||||
"utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
|
||||
"robo_net": robo_net_dataset_transform,
|
||||
"berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
|
||||
"berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
|
||||
"kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
|
||||
"stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
|
||||
"tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
|
||||
"dlr_sara_pour_converted_externally_to_rlds": identity_transform,
|
||||
"dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
|
||||
"dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
|
||||
"asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
|
||||
"stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
|
||||
"imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
|
||||
"iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
|
||||
"uiuc_d3field": uiuc_d3field_dataset_transform,
|
||||
"utaustin_mutex": utaustin_mutex_dataset_transform,
|
||||
"berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
|
||||
"cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
|
||||
"cmu_play_fusion": playfusion_dataset_transform,
|
||||
"cmu_stretch": cmu_stretch_dataset_transform,
|
||||
"berkeley_gnm_recon": gnm_dataset_transform,
|
||||
"berkeley_gnm_cory_hall": gnm_dataset_transform,
|
||||
"berkeley_gnm_sac_son": gnm_dataset_transform,
|
||||
"droid": droid_baseact_transform_fn(),
|
||||
"droid_100": droid_baseact_transform_fn(), # first 100 episodes of droid
|
||||
"fmb": fmb_transform,
|
||||
"dobbe": dobbe_dataset_transform,
|
||||
"robo_set": robo_set_dataset_transform,
|
||||
"usc_cloth_sim_converted_externally_to_rlds": identity_transform,
|
||||
"plex_robosuite": identity_transform,
|
||||
"conq_hose_manipulation": identity_transform,
|
||||
"io_ai_tech": identity_transform,
|
||||
"spoc": identity_transform,
|
||||
}
|
||||
@@ -14,16 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
For all datasets in the RLDS format.
|
||||
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.
|
||||
|
||||
NOTE: You need to install tensorflow and tensorflow_datsets before running this script.
|
||||
|
||||
Example:
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir /path/to/data/bridge_dataset/1.0.0/ \
|
||||
--repo-id your_hub/sampled_bridge_data_v2 \
|
||||
--raw-format rlds \
|
||||
--raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
|
||||
--repo-id youliangtan/sampled_bridge_data_v2 \
|
||||
--raw-format openx_rlds.bridge_orig \
|
||||
--episodes 3 4 5 8 9
|
||||
|
||||
Exact dataset fps defined in openx/config.py, obtained from:
|
||||
@@ -38,21 +35,28 @@ import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import torch
|
||||
import tqdm
|
||||
import yaml
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
|
||||
_openx_list = yaml.safe_load(f)
|
||||
|
||||
OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
|
||||
|
||||
@@ -104,6 +108,7 @@ def load_from_raw(
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
openx_dataset_name: str | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -131,23 +136,18 @@ def load_from_raw(
|
||||
# we will apply the standardization transform if the dataset_name is provided
|
||||
# if the dataset name is not provided and the goal is to convert any rlds formatted dataset
|
||||
# search for 'image' keys in the observations
|
||||
image_keys = []
|
||||
state_keys = []
|
||||
observation_info = dataset_info.features["steps"]["observation"]
|
||||
for key in observation_info:
|
||||
# check whether the key is for an image or a vector observation
|
||||
if len(observation_info[key].shape) == 3:
|
||||
# only adding uint8 images discards depth images
|
||||
if observation_info[key].dtype == tf.uint8:
|
||||
image_keys.append(key)
|
||||
else:
|
||||
state_keys.append(key)
|
||||
if openx_dataset_name is not None:
|
||||
print(" - applying standardization transform for dataset: ", openx_dataset_name)
|
||||
assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
|
||||
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
|
||||
dataset = dataset.map(transform_fn)
|
||||
|
||||
lang_key = (
|
||||
"language_instruction"
|
||||
if "language_instruction" in dataset.element_spec
|
||||
else None
|
||||
)
|
||||
image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
|
||||
else:
|
||||
obs_keys = dataset_info.features["steps"]["observation"].keys()
|
||||
image_keys = [key for key in obs_keys if "image" in key]
|
||||
|
||||
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
|
||||
|
||||
print(" - image_keys: ", image_keys)
|
||||
print(" - lang_key: ", lang_key)
|
||||
@@ -193,33 +193,50 @@ def load_from_raw(
|
||||
|
||||
num_frames = episode["action"].shape[0]
|
||||
|
||||
ep_dict = {}
|
||||
for key in state_keys:
|
||||
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
|
||||
###########################################################
|
||||
# Handle the episodic data
|
||||
|
||||
ep_dict["action"] = tf_to_torch(episode["action"])
|
||||
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
|
||||
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
|
||||
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
|
||||
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
|
||||
ep_dict["discount"] = tf_to_torch(episode["discount"])
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
ep_dict = {}
|
||||
langs = [] # TODO: might be located in "observation"
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
|
||||
# We will create the state observation tensor by stacking the state
|
||||
# obs keys defined in the openx/configs.py
|
||||
if openx_dataset_name is not None:
|
||||
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
|
||||
# stack the state observations, if is None, pad with zeros
|
||||
states = []
|
||||
for key in state_obs_keys:
|
||||
if key in episode["observation"]:
|
||||
states.append(tf_to_torch(episode["observation"][key]))
|
||||
else:
|
||||
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
||||
states = torch.cat(states, dim=1)
|
||||
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
|
||||
else:
|
||||
states = tf_to_torch(episode["observation"]["state"])
|
||||
|
||||
actions = tf_to_torch(episode["action"])
|
||||
rewards = tf_to_torch(episode["reward"]).float()
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = [
|
||||
x.numpy().decode("utf-8") for x in episode[lang_key]
|
||||
]
|
||||
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
langs = [str(x) for x in episode[lang_key]]
|
||||
|
||||
for im_key in image_keys:
|
||||
imgs = episode["observation"][im_key]
|
||||
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
||||
|
||||
# simple assertions
|
||||
for item in [states, actions, rewards, done]:
|
||||
assert len(item) == num_frames
|
||||
|
||||
###########################################################
|
||||
|
||||
# loop through all cameras
|
||||
for im_key in image_keys:
|
||||
img_key = f"observation.images.{im_key}"
|
||||
@@ -240,12 +257,22 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = langs
|
||||
|
||||
ep_dict["observation.state"] = states
|
||||
ep_dict["action"] = actions
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["next.reward"] = rewards
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
||||
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
||||
)
|
||||
@@ -263,30 +290,30 @@ def load_from_raw(
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
for key in data_dict:
|
||||
# check if vector state obs
|
||||
if key.startswith("observation.") and "observation.images." not in key:
|
||||
features[key] = Sequence(
|
||||
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
# check if image obs
|
||||
elif "observation.images." in key:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "language_instruction" in data_dict:
|
||||
features["language_instruction"] = Value(dtype="string", id=None)
|
||||
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
features["is_terminal"] = Value(dtype="bool", id=None)
|
||||
features["is_first"] = Value(dtype="bool", id=None)
|
||||
features["discount"] = Value(dtype="float32", id=None)
|
||||
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
@@ -306,8 +333,19 @@ def from_raw_to_lerobot_format(
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
openx_dataset_name: str | None = None,
|
||||
):
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
"""This is a test impl for rlds conversion"""
|
||||
if openx_dataset_name is None:
|
||||
# set a default rlds frame rate if the dataset is not from openx
|
||||
fps = 30
|
||||
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
|
||||
raise ValueError(
|
||||
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
|
||||
)
|
||||
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -56,9 +56,7 @@ def check_format(raw_dir):
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -78,9 +76,7 @@ def load_from_raw(
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
|
||||
)
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
|
||||
success_threshold = 0.95 # 95% coverage,
|
||||
@@ -154,9 +150,7 @@ def load_from_raw(
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(
|
||||
space, block_pos[i].tolist(), block_angle[i].item()
|
||||
)
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -165,9 +159,7 @@ def load_from_raw(
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
if keypoints_instead_of_image:
|
||||
keypoints[i] = torch.from_numpy(
|
||||
PushTEnv.get_keypoints(block_shapes).flatten()
|
||||
)
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
@@ -192,8 +184,7 @@ def load_from_raw(
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
@@ -202,9 +193,7 @@ def load_from_raw(
|
||||
if keypoints_instead_of_image:
|
||||
ep_dict["observation.environment_state"] = keypoints
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = image[1:],
|
||||
@@ -231,8 +220,7 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if keypoints_instead_of_image:
|
||||
features["observation.environment_state"] = Sequence(
|
||||
@@ -273,9 +261,7 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(
|
||||
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
|
||||
)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
|
||||
@@ -26,16 +26,14 @@ from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
|
||||
register_codecs,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -63,9 +61,7 @@ def check_format(raw_dir) -> bool:
|
||||
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
assert all(
|
||||
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
|
||||
)
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -83,9 +79,7 @@ def load_from_raw(
|
||||
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||
eff_rot_axis_angle = torch.from_numpy(
|
||||
zarr_data["data/robot0_eef_rot_axis_angle"][:]
|
||||
)
|
||||
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||
|
||||
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||
@@ -135,31 +129,24 @@ def load_from_raw(
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
encode_video_frames(
|
||||
tmp_imgs_dir, video_path, fps, **(encoding or {})
|
||||
)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor(
|
||||
[from_idx + num_frames] * num_frames
|
||||
)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
@@ -185,8 +172,7 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
@@ -206,8 +192,7 @@ def to_hf_dataset(data_dict, video):
|
||||
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["gripper_width"] = Sequence(
|
||||
length=data_dict["gripper_width"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
|
||||
@@ -16,9 +16,7 @@
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
@@ -45,9 +43,7 @@ def concatenate_episodes(ep_dicts):
|
||||
return data_dict
|
||||
|
||||
|
||||
def save_images_concurrently(
|
||||
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
|
||||
):
|
||||
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -57,10 +53,7 @@ def save_images_concurrently(
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[
|
||||
executor.submit(save_image, imgs_array[i], i, out_dir)
|
||||
for i in range(num_images)
|
||||
]
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -69,8 +62,7 @@ def get_default_encoding() -> dict:
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty
|
||||
and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
@@ -80,60 +72,3 @@ def check_repo_id(repo_id: str) -> None:
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -40,10 +40,7 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
def check_format(raw_dir):
|
||||
keys = {"actions", "rewards", "dones"}
|
||||
nested_keys = {
|
||||
"observations": {"rgb", "state"},
|
||||
"next_observations": {"rgb", "state"},
|
||||
}
|
||||
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
|
||||
|
||||
xarm_files = list(raw_dir.glob("*.pkl"))
|
||||
assert len(xarm_files) > 0
|
||||
@@ -56,17 +53,11 @@ def check_format(raw_dir):
|
||||
|
||||
# Check for consistent lengths in nested keys
|
||||
expected_len = len(dataset_dict["actions"])
|
||||
assert all(
|
||||
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
|
||||
)
|
||||
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
|
||||
|
||||
for key, subkeys in nested_keys.items():
|
||||
nested_dict = dataset_dict.get(key, {})
|
||||
assert all(
|
||||
len(nested_dict[subkey]) == expected_len
|
||||
for subkey in subkeys
|
||||
if subkey in nested_dict
|
||||
)
|
||||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
@@ -131,18 +122,13 @@ def load_from_raw(
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps}
|
||||
for i in range(num_frames)
|
||||
]
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor(
|
||||
[ep_idx] * num_frames, dtype=torch.int64
|
||||
)
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
# ep_dict["next.observation.image"] = next_image
|
||||
@@ -167,8 +153,7 @@ def to_hf_dataset(data_dict, video):
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
|
||||
@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(
|
||||
start_index.item() + drop_n_first_frames,
|
||||
end_index.item() - drop_n_last_frames,
|
||||
)
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
|
||||
@@ -57,9 +57,7 @@ class RandomSubsetApply(Transform):
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(
|
||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||
)
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
@@ -118,22 +116,16 @@ class SharpnessJitter(Transform):
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError(
|
||||
"If sharpness is a single number, it must be non negative."
|
||||
)
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||
)
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(
|
||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||
)
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
@@ -142,9 +134,7 @@ class SharpnessJitter(Transform):
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(
|
||||
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
|
||||
)
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
@@ -160,10 +150,6 @@ def get_image_transforms(
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
interpolation: str | None = None,
|
||||
image_size: tuple[int, int] | None = None,
|
||||
image_mean: list[float] | None = None,
|
||||
image_std: list[float] | None = None,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
@@ -184,22 +170,6 @@ def get_image_transforms(
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if image_size is not None:
|
||||
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
|
||||
if interpolation is None:
|
||||
# Use BICUBIC as default interpolation
|
||||
interpolation_mode = v2.InterpolationMode.BICUBIC
|
||||
elif interpolation in interpolations:
|
||||
interpolation_mode = v2.InterpolationMode(interpolation)
|
||||
else:
|
||||
raise ValueError("The interpolation passed is not supported")
|
||||
# Weight for resizing is always 1
|
||||
weights.append(1.0)
|
||||
transforms.append(
|
||||
v2.Resize(
|
||||
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
|
||||
)
|
||||
)
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
@@ -215,15 +185,6 @@ def get_image_transforms(
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
if image_mean is not None and image_std is not None:
|
||||
# Weight for normalization is always 1
|
||||
weights.append(1.0)
|
||||
transforms.append(
|
||||
v2.Normalize(
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
)
|
||||
)
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
@@ -233,6 +194,4 @@ def get_image_transforms(
|
||||
return v2.Identity()
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(
|
||||
transforms, p=weights, n_subset=n_subset, random_order=random_order
|
||||
)
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
|
||||
@@ -13,66 +13,31 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
import re
|
||||
import warnings
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
DEFAULT_PARQUET_PATH = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
DEFAULT_IMAGE_PATH = (
|
||||
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
)
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## {}
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
@@ -91,7 +56,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
@@ -104,89 +69,6 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {
|
||||
key: value.tolist() for key, value in flatten_dict(stats).items()
|
||||
}
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {
|
||||
item["task_index"]: item["task"]
|
||||
for item in sorted(tasks, key=lambda x: x["task_index"])
|
||||
}
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype="float32", channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if "float" in dtype:
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
@@ -198,6 +80,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif isinstance(first_item, str):
|
||||
# TODO (michel-aractingi): add str2embedding via language tokenizer
|
||||
# For now we leave this part up to the user to choose how to address
|
||||
# language conditioned tasks
|
||||
pass
|
||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||
# video frame will be processed downstream
|
||||
pass
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
@@ -205,70 +95,19 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str,
|
||||
version_to_check: str,
|
||||
current_version: str,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
)
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
@cache
|
||||
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
@@ -277,200 +116,275 @@ def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
return version
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
# TODO: (alibers, azouitine) Add support for ft["shap"] == 0 as Value
|
||||
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||
if split != "train":
|
||||
if "%" in split:
|
||||
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
|
||||
match_from = re.search(r"train\[(\d+):\]", split)
|
||||
match_to = re.search(r"train\[:(\d+)\]", split)
|
||||
if match_from:
|
||||
from_frame_index = int(match_from.group(1))
|
||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||
elif match_to:
|
||||
to_frame_index = int(match_to.group(1))
|
||||
hf_dataset = hf_dataset.select(range(to_frame_index))
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video" if use_videos else "image", **ft}
|
||||
for key, ft in robot.camera_features.items()
|
||||
}
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
robot_type: str,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"total_videos": 0,
|
||||
"total_chunks": 0,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {
|
||||
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
|
||||
}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
def load_info(repo_id, version, root) -> dict:
|
||||
"""info contains useful information regarding the dataset that are not stored elsewhere
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
Example:
|
||||
```python
|
||||
print("frame per second used to collect the video", info["fps"])
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
|
||||
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
def load_videos(repo_id, version, root) -> Path:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "videos"
|
||||
else:
|
||||
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
|
||||
path = Path(repo_dir) / "videos"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(
|
||||
~filtered_within_tolerance
|
||||
) # .squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
actual timestamps from the dataset.
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [
|
||||
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
|
||||
]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts
|
||||
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
|
||||
if not is_within
|
||||
]
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
||||
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
||||
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
|
||||
frames in the dataset.
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
||||
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
||||
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
|
||||
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
|
||||
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
|
||||
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
|
||||
of the episode are the same as the first frame of the episode.
|
||||
|
||||
return True
|
||||
Parameters:
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
|
||||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||
modality (e.g., "timestamp", "observation.image", "action").
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
|
||||
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
||||
smallest expected inter-frame period, but large enough to account for jitter.
|
||||
|
||||
Returns:
|
||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
|
||||
each modality (e.g. "observation.image_is_pad").
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
|
||||
issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_id = item["episode_index"].item()
|
||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = torch.stack(ep_timestamps)
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
current_ts = item["timestamp"].item()
|
||||
|
||||
for key in delta_timestamps:
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
is_pad = min_ > tolerance_s
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
# get dataset indices corresponding to frames to be loaded
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# load frames modality
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
|
||||
if isinstance(item[key][0], dict) and "path" in item[key][0]:
|
||||
# video mode where frame are expressed as dict of path and timestamp
|
||||
item[key] = item[key]
|
||||
else:
|
||||
item[key] = torch.stack(item[key])
|
||||
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def get_delta_indices(
|
||||
delta_timestamps: dict[str, list[float]], fps: int
|
||||
) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
return delta_indices
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
|
||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
|
||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||
|
||||
This brings the `episode_index` to the required format.
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
return hf_dataset
|
||||
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
|
||||
episode_idx_to_reset_idx_mapping = {
|
||||
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
|
||||
}
|
||||
|
||||
def modify_ep_idx_func(example):
|
||||
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
|
||||
return example
|
||||
|
||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
@@ -486,7 +400,7 @@ def cycle(iterable):
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
@@ -501,96 +415,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""
|
||||
Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
|
||||
if tags:
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
tags=card_tags,
|
||||
task_categories=["robotics"],
|
||||
configs=[
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (
|
||||
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
|
||||
).read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
template_str=card_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
>>> ns = IterableNamespace(data)
|
||||
>>> ns.name
|
||||
'Alice'
|
||||
>>> ns.details.age
|
||||
25
|
||||
>>> list(ns.keys())
|
||||
['name', 'details']
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, dict):
|
||||
setattr(self, key, IterableNamespace(value))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
card.text += text
|
||||
return card
|
||||
|
||||
@@ -1,924 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
|
||||
|
||||
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
|
||||
lerobot/configs/robot/aloha.yaml before running this script.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import (
|
||||
convert_dataset,
|
||||
parse_robot_config,
|
||||
)
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{fu2024mobile,
|
||||
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
|
||||
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
|
||||
booktitle = {arXiv},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Zhao2023LearningFB,
|
||||
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
|
||||
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
|
||||
journal={RSS},
|
||||
year={2023},
|
||||
volume={abs/2304.13705},
|
||||
url={https://arxiv.org/abs/2304.13705}
|
||||
}""").lstrip(),
|
||||
}
|
||||
PUSHT_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
XARM_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://www.nicklashansen.com/td-mpc/",
|
||||
"paper": "https://arxiv.org/abs/2203.04955",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
"""),
|
||||
}
|
||||
UNITREEH_INFO = {
|
||||
"license": "apache-2.0",
|
||||
}
|
||||
|
||||
DATASETS = {
|
||||
"aloha_mobile_cabinet": {
|
||||
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_chair": {
|
||||
"single_task": "Push the chairs in front of the desk to place them against it.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_elevator": {
|
||||
"single_task": "Take the elevator to the 1st floor.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_shrimp": {
|
||||
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wash_pan": {
|
||||
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wipe_wine": {
|
||||
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_static_battery": {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {
|
||||
"single_task": "Pick up the candy and unwrap it.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee_new": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_cups_open": {
|
||||
"single_task": "Pick up the plastic cup and open its lid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_fork_pick_up": {
|
||||
"single_task": "Pick up the fork and place it on the plate.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pingpong_test": {
|
||||
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pro_pencil": {
|
||||
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_screw_driver": {
|
||||
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_tape": {
|
||||
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_thread_velcro": {
|
||||
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_towel": {
|
||||
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup": {
|
||||
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup_left": {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {
|
||||
"single_task": "Slide open the ziploc bag.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht_image": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {
|
||||
"single_task": "Put the object into the bin.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_warehouse": {
|
||||
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
},
|
||||
"asu_table_top": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023modularity,
|
||||
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={1684--1695},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}
|
||||
@article{zhou2023learning,
|
||||
title={Learning modular language-conditioned robot policies through attention},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
|
||||
journal={Autonomous Robots},
|
||||
pages={1--21},
|
||||
year={2023},
|
||||
publisher={Springer}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_buds_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
||||
"paper": "https://arxiv.org/abs/2109.13841",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022bottom,
|
||||
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
||||
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
volume={7},
|
||||
number={2},
|
||||
pages={4126--4133},
|
||||
year={2022},
|
||||
publisher={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sailor_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sailor/",
|
||||
"paper": "https://arxiv.org/abs/2210.11435",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{nasiriany2022sailor,
|
||||
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
||||
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
|
||||
booktitle={Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sirius_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sirius/",
|
||||
"paper": "https://arxiv.org/abs/2211.08416",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{liu2022robot,
|
||||
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
||||
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
|
||||
booktitle = {Robotics: Science and Systems (RSS)},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/berkeley-ur5/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{BerkeleyUR5Website,
|
||||
title = {Berkeley {UR5} Demonstration Dataset},
|
||||
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
|
||||
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/cablerouting/home",
|
||||
"paper": "https://arxiv.org/abs/2307.08927",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2023multistage,
|
||||
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
||||
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
||||
journal = {arXiv pre-print},
|
||||
year = {2023},
|
||||
url = {https://arxiv.org/abs/2307.08927},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{fanuc_manipulation2023,
|
||||
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
|
||||
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/1709.10489",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{kahn2018self,
|
||||
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
||||
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
|
||||
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
|
||||
pages={5129--5136},
|
||||
year={2018},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/recon-robot",
|
||||
"paper": "https://arxiv.org/abs/2104.05859",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2021rapid,
|
||||
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
||||
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
|
||||
booktitle={5th Annual Conference on Robot Learning },
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=d_SWJhyKfVw}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/SACSoN-review",
|
||||
"paper": "https://arxiv.org/abs/2306.01874",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{hirose2023sacson,
|
||||
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
||||
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2306.01874},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_mvp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2203.06173",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@InProceedings{Radosavovic2022,
|
||||
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
||||
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
|
||||
booktitle = {CoRL},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_rpt": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2306.10007",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Radosavovic2023,
|
||||
title={Robot Learning with Sensorimotor Pre-training},
|
||||
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
|
||||
year={2023},
|
||||
journal={arXiv:2306.10007}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_franka_exploration_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://human-world-model.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2308.10901",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={RSS},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-fusion.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2312.04549",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chen2023playfusion,
|
||||
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
||||
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
|
||||
booktitle={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robo-affordances.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2304.08488",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{bahl2023affordances,
|
||||
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
||||
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
|
||||
booktitle={CVPR},
|
||||
year={2023}
|
||||
}
|
||||
@article{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chi2023diffusionpolicy,
|
||||
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
|
||||
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"conq_hose_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{ConqHoseManipData,
|
||||
author={Peter Mitrano and Dmitry Berenson},
|
||||
title={Conq Hose Manipulation Dataset, v1.15.0},
|
||||
year={2024},
|
||||
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_edan_shared_control": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://ieeexplore.ieee.org/document/9341156",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{vogel_edan_2020,
|
||||
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
|
||||
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
|
||||
year = {2020}
|
||||
}
|
||||
@inproceedings{quere_shared_2020,
|
||||
address = {Paris, France},
|
||||
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
|
||||
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
|
||||
year = {2020},
|
||||
pages = {7},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_grid_clamp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{padalkar2023guided,
|
||||
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
journal={Research square preprint rs-3289569/v1},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_pour": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{padalkar2023guiding,
|
||||
title={Guiding Reinforcement Learning with Shared Control Templates},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"droid_100": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://droid-dataset.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2403.12945",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{khazatsky2024droid,
|
||||
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
||||
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"fmb": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://functional-manipulation-benchmark.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.08553",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2024fmb,
|
||||
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2401.08553},
|
||||
year={2024}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"iamlab_cmu_pickup_insert": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
||||
"paper": "https://arxiv.org/abs/2401.14502",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{saxena2023multiresolution,
|
||||
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
||||
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=WuBv9-IGDUA}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
},
|
||||
"jaco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@software{dass2023jacoplay,
|
||||
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
|
||||
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
|
||||
title = {CLVR Jaco Play Dataset},
|
||||
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
|
||||
version = {1.0.0},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"kaist_nonprehensile": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{kimpre,
|
||||
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
|
||||
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
|
||||
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://jyopari.github.io/VINN/",
|
||||
"paper": "https://arxiv.org/abs/2112.01511",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{pari2021surprising,
|
||||
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
||||
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
|
||||
year={2021},
|
||||
eprint={2112.01511},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_franka_play_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-to-policy.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2210.10047",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{cui2022play,
|
||||
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
||||
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal = {arXiv preprint arXiv:2210.10047},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_rot_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://rot-robot.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2206.15469",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{haldar2023watch,
|
||||
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
||||
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={32--43},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"roboturk": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://roboturk.stanford.edu/dataset_real.html",
|
||||
"paper": "PAPER",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mandlekar2019scaling,
|
||||
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
|
||||
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
|
||||
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
pages={1048--1055},
|
||||
year={2019},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_hydra_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/hydra-il-2023",
|
||||
"paper": "https://arxiv.org/abs/2306.17237",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{belkhale2023hydra,
|
||||
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
||||
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
|
||||
journal={arxiv},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/visionandtouch",
|
||||
"paper": "https://arxiv.org/abs/1810.10191",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{lee2019icra,
|
||||
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
||||
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
||||
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2019},
|
||||
url={https://arxiv.org/abs/1810.10191}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_robocook": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://hshi74.github.io/robocook/",
|
||||
"paper": "https://arxiv.org/abs/2306.14447",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{shi2023robocook,
|
||||
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
||||
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
|
||||
journal={arXiv preprint arXiv:2306.14447},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"taco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
||||
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{rosete2022tacorl,
|
||||
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
||||
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
|
||||
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
||||
year = {2022}
|
||||
}
|
||||
@inproceedings{mees23hulc2,
|
||||
title={Grounding Language with Visual Affordances over Unstructured Data},
|
||||
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
|
||||
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2023},
|
||||
address = {London, UK}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"tokyo_u_lsmo": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "URL",
|
||||
"paper": "https://arxiv.org/abs/2107.05842",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@Article{Osa22,
|
||||
author = {Takayuki Osa},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
|
||||
year = {2022},
|
||||
number = {3},
|
||||
pages = {291--311},
|
||||
volume = {41},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"toto": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://toto-benchmark.org/",
|
||||
"paper": "https://arxiv.org/abs/2306.00942",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023train,
|
||||
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
||||
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_kitchen_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@ARTICLE{ucsd_kitchens,
|
||||
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
|
||||
title = {{ucsd kitchens Dataset}},
|
||||
year = {2023},
|
||||
month = {August}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_pick_and_place_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://owmcorl.github.io/#",
|
||||
"paper": "https://arxiv.org/abs/2310.16029",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@preprint{Feng2023Finetuning,
|
||||
title={Finetuning Offline World Models in the Real World},
|
||||
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robopil.github.io/d3fields/",
|
||||
"paper": "https://arxiv.org/abs/2309.16118",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{wang2023d3field,
|
||||
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
||||
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"usc_cloth_sim": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://uscresl.github.io/dmfd/",
|
||||
"paper": "https://arxiv.org/abs/2207.10148",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{salhotra2022dmfd,
|
||||
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
title={Learning Deformable Object Manipulation From Expert Demonstrations},
|
||||
year={2022},
|
||||
volume={7},
|
||||
number={4},
|
||||
pages={8775-8782},
|
||||
doi={10.1109/LRA.2022.3187843}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
||||
"paper": "https://arxiv.org/abs/2309.14320",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2023mutex,
|
||||
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
||||
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=PwqiqaaEzJ}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_opening_fridge": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_saytap": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://saytap.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2306.07580",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{saytap2023,
|
||||
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
||||
Tatsuya Harada},
|
||||
title = {SayTap: Language to Quadrupedal Locomotion},
|
||||
eprint = {arXiv:2306.07580},
|
||||
url = {https://saytap.github.io},
|
||||
note = {https://saytap.github.io},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_bimanual": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_pick_and_place": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"viola": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
||||
"paper": "https://arxiv.org/abs/2210.11339",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022viola,
|
||||
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
||||
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
|
||||
journal={6th Annual Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log.txt"
|
||||
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
|
||||
for num, (name, kwargs) in enumerate(DATASETS.items()):
|
||||
repo_id = f"lerobot/{name}"
|
||||
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
@@ -1,774 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
|
||||
one episode to the next.
|
||||
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
|
||||
|
||||
|
||||
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
|
||||
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
|
||||
recorded with. For now, only Aloha/Koch type robots are supported with this option.
|
||||
|
||||
|
||||
# 1. Single task dataset
|
||||
If your dataset contains a single task, you can simply provide it directly via the CLI with the
|
||||
'--single-task' option.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/aloha_sim_insertion_human_image \
|
||||
--single-task "Insert the peg into the socket." \
|
||||
--robot-config lerobot/configs/robot/aloha.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id aliberts/koch_tutorial \
|
||||
--single-task "Pick the Lego block and drop it in the box on the right." \
|
||||
--robot-config lerobot/configs/robot/koch.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
|
||||
# 2. Single task episodes
|
||||
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
|
||||
|
||||
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
|
||||
this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
|
||||
'--tasks-path' arg. This file should have the following structure where keys correspond to each
|
||||
episode_index in the dataset, and values are the language instruction for that episode.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"0": "Do something",
|
||||
"1": "Do something else",
|
||||
"2": "Do something",
|
||||
"3": "Go there",
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
# 3. Multi task episodes
|
||||
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
|
||||
parquet file, and you must provide this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import filecmp
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
write_jsonlines,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame, # noqa: F401
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
|
||||
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
|
||||
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
|
||||
V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(
|
||||
config_path: Path, config_overrides: list[str] | None = None
|
||||
) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["follower_arms"]
|
||||
for motor in robot_cfg["follower_arms"][arm]["motors"]
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["leader_arms"]
|
||||
for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
"action": action_names,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
safetensor_path = v1_dir / V1_STATS_PATH
|
||||
stats = load_file(safetensor_path)
|
||||
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
|
||||
json_path = v2_dir / STATS_PATH
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
||||
# Sanity check
|
||||
with open(json_path) as f:
|
||||
stats_json = json.load(f)
|
||||
|
||||
stats_json = flatten_dict(stats_json)
|
||||
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
|
||||
for key in stats:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: dict | None = None
|
||||
) -> dict[str, list]:
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
dtype = ft.dtype
|
||||
shape = (1,)
|
||||
names = None
|
||||
if isinstance(ft, datasets.Sequence):
|
||||
assert isinstance(ft.feature, datasets.Value)
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key]
|
||||
if robot_config
|
||||
else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
elif isinstance(ft, datasets.Image):
|
||||
dtype = "image"
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channel"]
|
||||
elif ft._type == "VideoFrame":
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channel"]
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": names,
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(
|
||||
dataset: Dataset, tasks_by_episodes: dict
|
||||
) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {
|
||||
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
return dataset, tasks
|
||||
|
||||
|
||||
def add_task_index_from_tasks_col(
|
||||
dataset: Dataset, tasks_col: str
|
||||
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = (
|
||||
df[tasks_col]
|
||||
.str.removeprefix(prefix_to_clean)
|
||||
.str.removesuffix(suffix_to_clean)
|
||||
)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = (
|
||||
df.groupby("episode_index")[tasks_col]
|
||||
.unique()
|
||||
.apply(lambda x: x.tolist())
|
||||
.to_dict()
|
||||
)
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
|
||||
# Build the dataset back from df
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
dataset = dataset.remove_columns(tasks_col)
|
||||
|
||||
return dataset, tasks, tasks_by_episode
|
||||
|
||||
|
||||
def split_parquet_by_episodes(
|
||||
dataset: Dataset,
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
output_dir: Path,
|
||||
) -> list:
|
||||
table = dataset.data.table
|
||||
episode_lengths = []
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk
|
||||
)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
pq.write_table(ep_table, output_file)
|
||||
|
||||
return episode_lengths
|
||||
|
||||
|
||||
def move_videos(
|
||||
repo_id: str,
|
||||
video_keys: list[str],
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
work_dir: Path,
|
||||
clean_gittatributes: Path,
|
||||
branch: str = "main",
|
||||
) -> None:
|
||||
"""
|
||||
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
|
||||
commands to fetch git lfs video files references to move them into subdirectories without having to
|
||||
actually download them.
|
||||
"""
|
||||
_lfs_clone(repo_id, work_dir, branch)
|
||||
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [
|
||||
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
|
||||
]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
|
||||
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
|
||||
|
||||
current_gittatributes = work_dir / ".gitattributes"
|
||||
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
|
||||
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
|
||||
|
||||
if lfs_untracked_videos:
|
||||
fix_lfs_video_files_tracking(work_dir, video_files)
|
||||
|
||||
if videos_moved:
|
||||
return
|
||||
|
||||
video_dirs = sorted(work_dir.glob("videos*/"))
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
for vid_key in video_keys:
|
||||
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key
|
||||
)
|
||||
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(
|
||||
video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
for dir in video_dirs:
|
||||
if (dir / video_file).is_file():
|
||||
video_path = dir / video_file
|
||||
break
|
||||
|
||||
video_path.rename(work_dir / target_path)
|
||||
|
||||
commit_message = "Move video files into chunk subdirectories"
|
||||
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(
|
||||
work_dir: Path, lfs_untracked_videos: list[str]
|
||||
) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
"""
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "rm", "--cached", *files],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
|
||||
|
||||
commit_message = "Track video files with git lfs"
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(
|
||||
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
|
||||
) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
|
||||
)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch,
|
||||
"--single-branch",
|
||||
"--depth",
|
||||
"1",
|
||||
repo_url,
|
||||
str(work_dir),
|
||||
],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def get_videos_info(
|
||||
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
|
||||
) -> dict:
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
for vid_key in video_keys
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
revision=branch,
|
||||
allow_patterns=video_files,
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
local_dir: Path,
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
v20_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=v1,
|
||||
local_dir=v1x_dir,
|
||||
ignore_patterns="videos*/",
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
branch = test_branch
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
features = get_features_from_hf_dataset(dataset, robot_config)
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
|
||||
if single_task and "language_instruction" in dataset.column_names:
|
||||
logging.warning(
|
||||
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
|
||||
)
|
||||
single_task = None
|
||||
tasks_col = "language_instruction"
|
||||
|
||||
# Episodes & chunks
|
||||
episode_indices = sorted(dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
assert episode_indices == list(range(total_episodes))
|
||||
total_videos = total_episodes * len(video_keys)
|
||||
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
|
||||
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
|
||||
total_chunks += 1
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {
|
||||
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {
|
||||
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
|
||||
}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
|
||||
dataset, tasks_col
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {
|
||||
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
|
||||
}
|
||||
tasks = [
|
||||
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
|
||||
]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
# Videos
|
||||
if video_keys:
|
||||
assert metadata_v1.get("video", False)
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
filename=".gitattributes",
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id,
|
||||
video_keys,
|
||||
total_episodes,
|
||||
total_chunks,
|
||||
Path(tmp_video_dir),
|
||||
clean_gitattr,
|
||||
branch,
|
||||
)
|
||||
videos_info = get_videos_info(
|
||||
repo_id, v1x_dir, video_keys=video_keys, branch=branch
|
||||
)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
videos_info[key].pop("video.width"),
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(
|
||||
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
|
||||
)
|
||||
if "encoding" in metadata_v1:
|
||||
assert (
|
||||
videos_info[key]["video.pix_fmt"]
|
||||
== metadata_v1["encoding"]["pix_fmt"]
|
||||
)
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(
|
||||
dataset, total_episodes, total_chunks, v20_dir
|
||||
)
|
||||
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
repo_tags = None
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": tasks_by_episodes[ep_idx],
|
||||
"length": episode_lengths[ep_idx],
|
||||
}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
"codebase_version": V20,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": len(dataset),
|
||||
"total_tasks": len(tasks),
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||
"features": features,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta_data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
|
||||
)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="data",
|
||||
folder_path=v20_dir / "data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta",
|
||||
folder_path=v20_dir / "meta",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if not test_branch:
|
||||
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
task_args = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the single task performed in the dataset.",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-col",
|
||||
type=str,
|
||||
help="The name of the column containing language instructions",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-path",
|
||||
type=Path,
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-config",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to the robot's config yaml the dataset during conversion.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--license",
|
||||
type=str,
|
||||
default="apache-2.0",
|
||||
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = (
|
||||
parse_robot_config(args.robot_config, args.robot_overrides)
|
||||
if args.robot_config
|
||||
else None
|
||||
)
|
||||
del args.robot_config, args.robot_overrides
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
@@ -26,11 +25,47 @@ import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor],
|
||||
video_frame_keys: list[str],
|
||||
videos_dir: Path,
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
):
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
||||
This probably happens because a memory reference to the video loader is created in the main process and a
|
||||
subprocess fails to access it.
|
||||
"""
|
||||
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
|
||||
data_dir = videos_dir.parent
|
||||
|
||||
for key in video_frame_keys:
|
||||
if isinstance(item[key], list):
|
||||
# load multiple frames at once (expected when delta_timestamps is not None)
|
||||
timestamps = [frame["timestamp"] for frame in item[key]]
|
||||
paths = [frame["path"] for frame in item[key]]
|
||||
if len(set(paths)) > 1:
|
||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
@@ -128,8 +163,8 @@ def decode_video_frames_torchvision(
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
@@ -212,110 +247,3 @@ with warnings.catch_warnings():
|
||||
)
|
||||
# to make VideoFrame available in HuggingFace `datasets`
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"])
|
||||
if audio_stream_info.get("bit_rate")
|
||||
else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
@@ -14,13 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
# from mani_skill.utils import common
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
@@ -34,12 +30,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(
|
||||
cfg, n_envs if n_envs is not None else cfg.eval.batch_size
|
||||
)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
@@ -57,11 +47,7 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = (
|
||||
gym.vector.AsyncVectorEnv
|
||||
if cfg.eval.use_async_envs
|
||||
else gym.vector.SyncVectorEnv
|
||||
)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
@@ -70,99 +56,3 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(
|
||||
cfg: DictConfig, n_envs: int | None = None
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {
|
||||
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
|
||||
self.env.device
|
||||
)
|
||||
}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
|
||||
# TODO: Remove this
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options={})
|
||||
return self._get_obs(obs), info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
def _get_obs(self, observation):
|
||||
sensor_data = observation.pop("sensor_data")
|
||||
del observation["sensor_param"]
|
||||
images = []
|
||||
for cam_data in sensor_data.values():
|
||||
images.append(cam_data["rgb"])
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
return ret
|
||||
|
||||
@@ -28,32 +28,28 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
for key, img in observations.items():
|
||||
if "images" not in key:
|
||||
continue
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel last images, but instead got {img.shape=}"
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
return_observations[key] = img
|
||||
# obs state agent qpos and qvel
|
||||
# image
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
@@ -62,43 +58,5 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return_observations["observation.state"] = observations["observation.state"].float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
@@ -25,7 +25,6 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -84,9 +83,7 @@ class Logger:
|
||||
pretrained_model_dir_name = "pretrained_model"
|
||||
training_state_file_name = "training_state.pth"
|
||||
|
||||
def __init__(
|
||||
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
|
||||
):
|
||||
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
|
||||
"""
|
||||
Args:
|
||||
log_dir: The directory to save all logs and training outputs to.
|
||||
@@ -106,12 +103,12 @@ class Logger:
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(
|
||||
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
|
||||
)
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
@@ -131,12 +128,8 @@ class Logger:
|
||||
job_type="train_eval",
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
# Handle custom step key for rl asynchronous training.
|
||||
self._wandb_custom_step_key: set[str] | None = None
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
@classmethod
|
||||
@@ -157,9 +150,7 @@ class Logger:
|
||||
"""
|
||||
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
|
||||
|
||||
def save_model(
|
||||
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
|
||||
):
|
||||
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
|
||||
"""Save the weights of the Policy model using PyTorchModelHubMixin.
|
||||
|
||||
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
|
||||
@@ -182,32 +173,18 @@ class Logger:
|
||||
self,
|
||||
save_dir: Path,
|
||||
train_step: int,
|
||||
optimizer: Optimizer | dict,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
interaction_step: int | None = None,
|
||||
):
|
||||
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
|
||||
|
||||
All of these are saved as "training_state.pth" under the checkpoint directory.
|
||||
"""
|
||||
# In Sac, for example, we have a dictionary of torch.optim.Optimizer
|
||||
if type(optimizer) is dict:
|
||||
optimizer_state_dict = {}
|
||||
for k in optimizer:
|
||||
optimizer_state_dict[k] = optimizer[k].state_dict()
|
||||
else:
|
||||
optimizer_state_dict = optimizer.state_dict()
|
||||
|
||||
training_state = {
|
||||
"step": train_step,
|
||||
"optimizer": optimizer_state_dict,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
**get_global_random_state(),
|
||||
}
|
||||
# Interaction step is related to the distributed training code
|
||||
# In that setup, we have two kinds of steps, the online step of the env and the optimization step
|
||||
# We need to save both in order to resume the optimization properly and not break the logs dependant on the interaction step
|
||||
if interaction_step is not None:
|
||||
training_state["interaction_step"] = interaction_step
|
||||
if scheduler is not None:
|
||||
training_state["scheduler"] = scheduler.state_dict()
|
||||
torch.save(training_state, save_dir / self.training_state_file_name)
|
||||
@@ -219,7 +196,6 @@ class Logger:
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None,
|
||||
identifier: str,
|
||||
interaction_step: int | None = None,
|
||||
):
|
||||
"""Checkpoint the model weights and the training state."""
|
||||
checkpoint_dir = self.checkpoints_dir / str(identifier)
|
||||
@@ -229,34 +205,18 @@ class Logger:
|
||||
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
|
||||
)
|
||||
self.save_model(
|
||||
checkpoint_dir / self.pretrained_model_dir_name,
|
||||
policy,
|
||||
wandb_artifact_name=wandb_artifact_name,
|
||||
)
|
||||
self.save_training_state(
|
||||
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
|
||||
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
|
||||
)
|
||||
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
|
||||
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
|
||||
|
||||
def load_last_training_state(
|
||||
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
|
||||
) -> int:
|
||||
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
|
||||
"""
|
||||
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
|
||||
random state, and return the global training step.
|
||||
"""
|
||||
training_state = torch.load(
|
||||
self.last_checkpoint_dir / self.training_state_file_name
|
||||
)
|
||||
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
|
||||
if type(training_state["optimizer"]) is dict:
|
||||
assert set(training_state["optimizer"].keys()) == set(
|
||||
optimizer.keys()
|
||||
), "Optimizer dictionaries do not have the same keys during resume!"
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizer[k].load_state_dict(v)
|
||||
else:
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
|
||||
optimizer.load_state_dict(training_state["optimizer"])
|
||||
if scheduler is not None:
|
||||
scheduler.load_state_dict(training_state["scheduler"])
|
||||
elif "scheduler" in training_state:
|
||||
@@ -264,63 +224,20 @@ class Logger:
|
||||
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
|
||||
)
|
||||
# Small hack to get the expected keys: use `get_global_random_state`.
|
||||
set_global_random_state(
|
||||
{k: training_state[k] for k in get_global_random_state()}
|
||||
)
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"]
|
||||
|
||||
def log_dict(
|
||||
self,
|
||||
d,
|
||||
step: int | None = None,
|
||||
mode="train",
|
||||
custom_step_key: str | None = None,
|
||||
):
|
||||
"""Log a dictionary of metrics to WandB."""
|
||||
def log_dict(self, d, step, mode="train"):
|
||||
assert mode in {"train", "eval"}
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if step is None and custom_step_key is None:
|
||||
raise ValueError("Either step or custom_step_key must be provided.")
|
||||
|
||||
if self._wandb is not None:
|
||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||
# multiple time steps is possible for example, the interaction step with the environment,
|
||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||
# to log the correct step for each metric.
|
||||
if custom_step_key is not None:
|
||||
if self._wandb_custom_step_key is None:
|
||||
self._wandb_custom_step_key = set()
|
||||
new_custom_key = f"{mode}/{custom_step_key}"
|
||||
if new_custom_key not in self._wandb_custom_step_key:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if (
|
||||
self._wandb_custom_step_key is not None
|
||||
and k in self._wandb_custom_step_key
|
||||
):
|
||||
continue
|
||||
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
self._wandb.log(
|
||||
{
|
||||
f"{mode}/{k}": v,
|
||||
f"{mode}/{custom_step_key}": value_custom_step,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
|
||||
@@ -168,6 +168,4 @@ class ACTConfig:
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@@ -81,14 +81,10 @@ class ACTPolicy(
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(
|
||||
config.temporal_ensemble_coeff, config.chunk_size
|
||||
)
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -111,12 +107,8 @@ class ACTPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -143,18 +135,13 @@ class ACTPolicy(
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
* ~batch["action_is_pad"].unsqueeze(-1)
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
@@ -164,12 +151,7 @@ class ACTPolicy(
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(
|
||||
-0.5
|
||||
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
|
||||
)
|
||||
.sum(-1)
|
||||
.mean()
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
@@ -223,9 +205,7 @@ class ACTTemporalEnsembler:
|
||||
```
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(
|
||||
-temporal_ensemble_coeff * torch.arange(chunk_size)
|
||||
)
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.reset()
|
||||
|
||||
@@ -241,9 +221,7 @@ class ACTTemporalEnsembler:
|
||||
time steps, and pop/return the next batch of actions in the sequence.
|
||||
"""
|
||||
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
|
||||
device=actions.device
|
||||
)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||
if self.ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
@@ -251,34 +229,19 @@ class ACTTemporalEnsembler:
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1),
|
||||
dtype=torch.long,
|
||||
device=self.ensembled_actions.device,
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the online update for those entries.
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count - 1
|
||||
]
|
||||
self.ensembled_actions += (
|
||||
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
)
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[
|
||||
self.ensembled_actions_count
|
||||
]
|
||||
self.ensembled_actions_count = torch.clamp(
|
||||
self.ensembled_actions_count + 1, max=self.chunk_size
|
||||
)
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat(
|
||||
[self.ensembled_actions, actions[:, -1:]], dim=1
|
||||
)
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[
|
||||
self.ensembled_actions_count,
|
||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
||||
]
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
@@ -330,9 +293,7 @@ class ACT(nn.Module):
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(
|
||||
k.startswith("observation.image") for k in config.input_shapes
|
||||
)
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
|
||||
@@ -347,9 +308,7 @@ class ACT(nn.Module):
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
)
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(
|
||||
config.dim_model, config.latent_dim * 2
|
||||
)
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
@@ -357,28 +316,20 @@ class ACT(nn.Module):
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
create_sinusoidal_pos_embedding(
|
||||
num_input_token_encoder, config.dim_model
|
||||
).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
config.replace_final_stride_with_dilation,
|
||||
],
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(
|
||||
backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
@@ -392,8 +343,7 @@ class ACT(nn.Module):
|
||||
)
|
||||
if self.use_env_state:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0],
|
||||
config.dim_model,
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
@@ -408,18 +358,14 @@ class ACT(nn.Module):
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
|
||||
config.dim_model // 2
|
||||
)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(
|
||||
config.dim_model, config.output_shapes["action"][0]
|
||||
)
|
||||
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
@@ -429,9 +375,7 @@ class ACT(nn.Module):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
@@ -468,20 +412,12 @@ class ACT(nn.Module):
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(
|
||||
batch["observation.state"]
|
||||
)
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(
|
||||
batch["action"]
|
||||
) # (B, S, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -519,26 +455,20 @@ class ACT(nn.Module):
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros(
|
||||
[batch_size, self.config.latent_dim], dtype=torch.float32
|
||||
).to(batch["observation.state"].device)
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(
|
||||
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
|
||||
)
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
)
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(
|
||||
batch["observation.environment_state"]
|
||||
)
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Camera observation features and positional embeddings.
|
||||
@@ -547,29 +477,19 @@ class ACT(nn.Module):
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
|
||||
"feature_map"
|
||||
]
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
|
||||
dtype=cam_features.dtype
|
||||
)
|
||||
cam_features = self.encoder_img_feat_input_proj(
|
||||
cam_features
|
||||
) # (B, C, h, w)
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(
|
||||
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
|
||||
)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(
|
||||
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
|
||||
)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
@@ -604,21 +524,12 @@ class ACTEncoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
|
||||
super().__init__()
|
||||
self.is_vae_encoder = is_vae_encoder
|
||||
num_layers = (
|
||||
config.n_vae_encoder_layers
|
||||
if self.is_vae_encoder
|
||||
else config.n_encoder_layers
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTEncoderLayer(config) for _ in range(num_layers)]
|
||||
)
|
||||
num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers
|
||||
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_embed: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
@@ -629,9 +540,7 @@ class ACTEncoder(nn.Module):
|
||||
class ACTEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -646,9 +555,7 @@ class ACTEncoderLayer(nn.Module):
|
||||
self.activation = get_activation_fn(config.feedforward_activation)
|
||||
self.pre_norm = config.pre_norm
|
||||
|
||||
def forward(
|
||||
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
@@ -673,9 +580,7 @@ class ACTDecoder(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
|
||||
)
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
def forward(
|
||||
@@ -687,10 +592,7 @@ class ACTDecoder(nn.Module):
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
@@ -700,12 +602,8 @@ class ACTDecoder(nn.Module):
|
||||
class ACTDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
config.dim_model, config.n_heads, dropout=config.dropout
|
||||
)
|
||||
self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
|
||||
@@ -746,9 +644,7 @@ class ACTDecoderLayer(nn.Module):
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[
|
||||
0
|
||||
] # select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
@@ -785,14 +681,9 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / dimension)
|
||||
for hid_j in range(dimension)
|
||||
]
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
|
||||
)
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
return torch.from_numpy(sinusoid_table).float()
|
||||
@@ -837,9 +728,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2
|
||||
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
|
||||
/ self.dimension
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
)
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
@@ -847,15 +736,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack(
|
||||
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed_y = torch.stack(
|
||||
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
|
||||
).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
|
||||
0, 3, 1, 2
|
||||
) # (1, C, H, W)
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
@@ -121,9 +121,7 @@ class DiffusionConfig:
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -165,13 +163,8 @@ class DiffusionConfig:
|
||||
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if (
|
||||
len(image_keys) == 0
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
|
||||
@@ -88,9 +88,7 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
|
||||
self.reset()
|
||||
@@ -104,9 +102,7 @@ class DiffusionPolicy(
|
||||
if len(self.expected_image_keys) > 0:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -132,22 +128,14 @@ class DiffusionPolicy(
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
@@ -162,12 +150,8 @@ class DiffusionPolicy(
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
@@ -193,9 +177,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len(
|
||||
[k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
)
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
@@ -211,9 +193,7 @@ class DiffusionModel(nn.Module):
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config, global_cond_dim=global_cond_dim * config.n_obs_steps
|
||||
)
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
@@ -233,21 +213,14 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(
|
||||
batch_size,
|
||||
self.config.horizon,
|
||||
self.config.output_shapes["action"][0],
|
||||
),
|
||||
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -263,9 +236,7 @@ class DiffusionModel(nn.Module):
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(
|
||||
model_output, t, sample, generator=generator
|
||||
).prev_sample
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
@@ -277,39 +248,27 @@ class DiffusionModel(nn.Module):
|
||||
if self._use_images:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> n (b s) ..."
|
||||
)
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(
|
||||
self.rgb_encoder, images_per_camera, strict=True
|
||||
)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list,
|
||||
"(n b s) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(
|
||||
batch["observation.images"], "b s n ... -> (b s n) ..."
|
||||
)
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features,
|
||||
"(b s n) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
@@ -395,9 +354,7 @@ class DiffusionModel(nn.Module):
|
||||
elif self.config.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported prediction type {self.config.prediction_type}"
|
||||
)
|
||||
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
@@ -457,9 +414,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -501,9 +456,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -524,9 +477,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -534,25 +485,17 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape
|
||||
if config.crop_shape is not None
|
||||
else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(
|
||||
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -579,9 +522,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -594,11 +535,7 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -613,9 +550,7 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -643,9 +578,7 @@ class DiffusionConv1dBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
||||
),
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
@@ -668,13 +601,9 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
|
||||
),
|
||||
nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(
|
||||
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
|
||||
),
|
||||
nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
@@ -699,16 +628,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -717,14 +640,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -737,24 +656,16 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_in * 2, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
dim_out, dim_out, **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
|
||||
if not is_last
|
||||
else nn.Identity(),
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
@@ -822,23 +733,17 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = DiffusionConv1dBlock(
|
||||
in_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = DiffusionConv1dBlock(
|
||||
out_channels, out_channels, kernel_size, n_groups=n_groups
|
||||
)
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1)
|
||||
if in_channels != out_channels
|
||||
else nn.Identity()
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
|
||||
@@ -51,10 +51,15 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
|
||||
elif name == "tdmpc2":
|
||||
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
|
||||
from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
|
||||
|
||||
return TDMPC2Policy, TDMPC2Config
|
||||
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import (
|
||||
DiffusionConfig,
|
||||
)
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy, DiffusionConfig
|
||||
@@ -68,21 +73,12 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy, SACConfig
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
hydra_cfg: DictConfig,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
dataset_stats=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
|
||||
) -> Policy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
@@ -96,19 +92,17 @@ def make_policy(
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
"""
|
||||
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
# raise ValueError(
|
||||
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
# )
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
)
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
if pretrained_policy_name_or_path is None:
|
||||
# Make a fresh policy.
|
||||
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
|
||||
# for example device for the sac policy.
|
||||
policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
else:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
@@ -117,9 +111,7 @@ def make_policy(
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(
|
||||
policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
|
||||
)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -1,173 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(
|
||||
self.config.model_name, trust_remote_code=True
|
||||
)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
self.encoder = self.encoder.to(self.config.device)
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported transformer architecture since hidden_size is not found"
|
||||
)
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
# processed = self.processor(
|
||||
# images=x, # LeRobotDataset already provides proper tensor format
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
# processed = processed["pixel_values"].to(x.device)
|
||||
processed = x
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if (
|
||||
hasattr(outputs, "pooler_output")
|
||||
and outputs.pooler_output is not None
|
||||
):
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(
|
||||
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
|
||||
)
|
||||
|
||||
def predict_reward(self, x, threshold=0.6):
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.forward(x).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
@@ -1,29 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class HILSerlPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
@@ -130,7 +130,7 @@ class Normalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
# @torch.no_grad
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
# @torch.no_grad
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"observation.environment_state": "min_max",
|
||||
}
|
||||
)
|
||||
input_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [[0.485, 0.456, 0.406]],
|
||||
"std": [[0.229, 0.224, 0.225]],
|
||||
},
|
||||
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
# TODO: Move it outside of the config
|
||||
actor_learner_config: dict[str, str | int] = field(
|
||||
default_factory=lambda: {
|
||||
"learner_host": "127.0.0.1",
|
||||
"learner_port": 50051,
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
|
||||
storage_device: str = "cpu"
|
||||
# Add type annotations for these fields:
|
||||
vision_encoder_name: str | None = field(default="helper2424/resnet10")
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
critic_target_update_weight: float = 0.005
|
||||
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
"final_activation": None,
|
||||
}
|
||||
)
|
||||
actor_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
)
|
||||
policy_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
"init_final": 0.05,
|
||||
}
|
||||
)
|
||||
@@ -1,981 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union, Dict, List
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.input_normalization_params
|
||||
)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes,
|
||||
config.input_normalization_modes,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = _convert_normalization_params_to_tensor(
|
||||
config.output_normalization_params
|
||||
)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# NOTE: For images the encoder should be shared between the actor and critic
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
|
||||
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
|
||||
|
||||
# Create a list of critic heads
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
# Create target critic heads as deepcopies of the original critic heads
|
||||
target_critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
]
|
||||
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=encoder_critic,
|
||||
ensemble=target_critic_heads,
|
||||
output_normalization=self.normalize_targets,
|
||||
)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
|
||||
),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = (
|
||||
-np.prod(config.output_shapes["action"][0]) / 2
|
||||
) # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
# it triggers "can't optimize a non-leaf Tensor"
|
||||
|
||||
temperature_init = config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
"""Custom save method to handle TensorDict properly"""
|
||||
import os
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import save_model
|
||||
|
||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
# Save config
|
||||
config_dict = asdict(self.config)
|
||||
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
|
||||
|
||||
@classmethod
|
||||
def _from_pretrained(
|
||||
cls,
|
||||
*,
|
||||
model_id: str,
|
||||
revision: Optional[str],
|
||||
cache_dir: Optional[Union[str, Path]],
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict],
|
||||
resume_download: Optional[bool],
|
||||
local_files_only: bool,
|
||||
token: Optional[Union[str, bool]],
|
||||
map_location: str = "cpu",
|
||||
strict: bool = False,
|
||||
**model_kwargs,
|
||||
) -> "SACPolicy":
|
||||
"""Custom load method to handle loading SAC policy from saved files"""
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from safetensors.torch import load_model
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
# Check if model_id is a local path or a hub model ID
|
||||
if os.path.isdir(model_id):
|
||||
model_path = Path(model_id)
|
||||
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
|
||||
config_file = os.path.join(model_path, CONFIG_NAME)
|
||||
else:
|
||||
# Download the safetensors file from the hub
|
||||
safetensors_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
# Download the config file
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except Exception:
|
||||
config_file = None
|
||||
|
||||
# Load or create config
|
||||
if config_file and os.path.exists(config_file):
|
||||
# Load config from file
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
config = SACConfig(**config_dict)
|
||||
else:
|
||||
# Use the provided config or create a default one
|
||||
config = model_kwargs.get("config", SACConfig())
|
||||
|
||||
# Create a new instance with the loaded config
|
||||
model = cls(config=config)
|
||||
|
||||
# Load state dict from safetensors file
|
||||
if os.path.exists(safetensors_file):
|
||||
load_model(model, filename=safetensors_file, device=map_location)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
|
||||
if self.actor.fixed_std is not None:
|
||||
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
|
||||
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
|
||||
super().to(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=False,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
|
||||
"action"
|
||||
]
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (
|
||||
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
|
||||
).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
|
||||
# If we're at the final layer and a final activation is specified, use it
|
||||
if (
|
||||
i + 1 == len(hidden_dims)
|
||||
and activate_final
|
||||
and final_activation is not None
|
||||
):
|
||||
layers.append(
|
||||
final_activation
|
||||
if isinstance(final_activation, nn.Module)
|
||||
else getattr(nn, final_activation)()
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
┌──────────────────┬─────────────────────────────────────────────────────────┐
|
||||
│ Critic Ensemble │ │
|
||||
├──────────────────┘ │
|
||||
│ │
|
||||
│ ┌────┐ ┌────┐ ┌────┐ │
|
||||
│ │ Q1 │ │ Q2 │ │ Qn │ │
|
||||
│ └────┘ └────┘ └────┘ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
|
||||
│ │ │ │ │ ... │ num_critics │ │
|
||||
│ │ │ │ │ │ │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
│ ▲ ▲ ▲ │
|
||||
│ └───────────────────┴───────┬────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ ┌───────────────────┐ │
|
||||
│ │ Embedding │ │
|
||||
│ │ │ │
|
||||
│ └───────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ ┌─────────────┴────────────┐ │
|
||||
│ │ │ │
|
||||
│ │ SACObservationEncoder │ │
|
||||
│ │ │ │
|
||||
│ └──────────────────────────┘ │
|
||||
│ ▲ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
│ │ │
|
||||
└───────────────────────────┬────────────────────┬───────────────────────────┘
|
||||
│ Observation │
|
||||
└────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
ensemble: List[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
# Handle the case where a part of the encoder if frozen
|
||||
if self.encoder is not None:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
|
||||
self.parameters_to_optimize += list(self.critics.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# NOTE: We normalize actions it helps for sample efficiency
|
||||
actions: dict[str, torch.tensor] = {"action": actions}
|
||||
# NOTE: Normalization layer took dict in input and outputs a dict that why
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
log_std_min: float = -5,
|
||||
log_std_max: float = 2,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log(
|
||||
(1 - actions.pow(2)) + 1e-6
|
||||
) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
device = get_device_from_parameters(self)
|
||||
observations = observations.to(device)
|
||||
if self.encoder is not None:
|
||||
with torch.inference_mode():
|
||||
return self.encoder(observations)
|
||||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self.has_pretrained_vision_encoder = False
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.aggregation_size: int = 0
|
||||
if any("observation.image" in key for key in config.input_shapes):
|
||||
self.camera_number = config.camera_number
|
||||
|
||||
if self.config.vision_encoder_name is not None:
|
||||
self.image_enc_layers = PretrainedImageEncoder(config)
|
||||
self.has_pretrained_vision_encoder = True
|
||||
else:
|
||||
self.image_enc_layers = DefaultImageEncoder(config)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
|
||||
if config.freeze_vision_encoder:
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
|
||||
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.environment_state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(
|
||||
in_features=self.aggregation_size, out_features=config.latent_dim
|
||||
)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat(
|
||||
[obs_dict[key] for key in self.all_image_keys], dim=0
|
||||
)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(
|
||||
images_batched, dim=0, chunks=len(self.all_image_keys)
|
||||
)
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_shapes["observation.image"][0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.image_enc_layers(x)
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = (
|
||||
self._load_pretrained_vision_encoder(config)
|
||||
)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(
|
||||
config.vision_encoder_name, trust_remote_code=True
|
||||
)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
elif hasattr(self.image_enc_layers, "fc"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported vision encoder architecture, make sure you are using a CNN"
|
||||
)
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
|
||||
# doesn't reach the classifier layer because we don't need it
|
||||
enc_feat = self.image_enc_layers(x).pooler_output
|
||||
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
|
||||
return enc_feat
|
||||
|
||||
|
||||
def freeze_image_encoder(image_encoder: nn.Module):
|
||||
"""Freeze all parameters in the encoder"""
|
||||
for param in image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
converted_params = {}
|
||||
for outer_key, inner_dict in normalization_params.items():
|
||||
converted_params[outer_key] = {}
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][
|
||||
key
|
||||
].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Benchmark the CriticEnsemble performance
|
||||
import time
|
||||
|
||||
# Configuration
|
||||
num_critics = 10
|
||||
batch_size = 32
|
||||
action_dim = 7
|
||||
obs_dim = 64
|
||||
hidden_dims = [256, 256]
|
||||
num_iterations = 100
|
||||
|
||||
print("Creating test environment...")
|
||||
|
||||
# Create a simple dummy encoder
|
||||
class DummyEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_dim = obs_dim
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
def forward(self, obs):
|
||||
# Just return a random tensor of the right shape
|
||||
# In practice, this would encode the observations
|
||||
return torch.randn(batch_size, obs_dim, device=device)
|
||||
|
||||
# Create critic heads
|
||||
print(f"Creating {num_critics} critic heads...")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
critic_heads = [
|
||||
CriticHead(
|
||||
input_dim=obs_dim + action_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
).to(device)
|
||||
for _ in range(num_critics)
|
||||
]
|
||||
|
||||
# Create the critic ensemble
|
||||
print("Creating CriticEnsemble...")
|
||||
critic_ensemble = CriticEnsemble(
|
||||
encoder=DummyEncoder().to(device),
|
||||
ensemble=critic_heads,
|
||||
output_normalization=nn.Identity(),
|
||||
).to(device)
|
||||
|
||||
# Create random input data
|
||||
print("Creating input data...")
|
||||
obs_dict = {
|
||||
"observation.state": torch.randn(batch_size, obs_dim, device=device),
|
||||
}
|
||||
actions = torch.randn(batch_size, action_dim, device=device)
|
||||
|
||||
# Warmup run
|
||||
print("Warming up...")
|
||||
_ = critic_ensemble(obs_dict, actions)
|
||||
|
||||
# Time the forward pass
|
||||
print(f"Running benchmark with {num_iterations} iterations...")
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_iterations):
|
||||
q_values = critic_ensemble(obs_dict, actions)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Print results
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"Total time: {elapsed_time:.4f} seconds")
|
||||
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
|
||||
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
|
||||
|
||||
# Verify that all critic heads produce different outputs
|
||||
# This confirms each critic head is unique
|
||||
# print("\nVerifying critic outputs are different:")
|
||||
# for i in range(num_critics):
|
||||
# for j in range(i + 1, num_critics):
|
||||
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
|
||||
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
|
||||
@@ -191,10 +191,6 @@ class TDMPCConfig:
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
|
||||
)
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError(
|
||||
"`n_action_steps` must be less than or equal to `horizon`."
|
||||
)
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
@@ -68,9 +68,7 @@ class TDMPCPolicy(
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -102,9 +100,7 @@ class TDMPCPolicy(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
@@ -124,9 +120,7 @@ class TDMPCPolicy(
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(
|
||||
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
|
||||
),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
@@ -141,9 +135,7 @@ class TDMPCPolicy(
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -217,20 +209,13 @@ class TDMPCPolicy(
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(
|
||||
z,
|
||||
"b d -> n b d",
|
||||
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
|
||||
)
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -246,47 +231,35 @@ class TDMPCPolicy(
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(
|
||||
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(
|
||||
value, self.config.n_elites, dim=0
|
||||
).indices # (n_elites, batch)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(
|
||||
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
|
||||
)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(
|
||||
self.config.elite_weighting_temperature * (elite_value - max_value)
|
||||
)
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
|
||||
)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
|
||||
** 2,
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
# Update mean with an exponential moving average, and std with a direct replacement.
|
||||
mean = (
|
||||
self.config.gaussian_mean_momentum * mean
|
||||
+ (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
|
||||
)
|
||||
std = _std.clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
@@ -295,9 +268,7 @@ class TDMPCPolicy(
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[
|
||||
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
|
||||
]
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
|
||||
return actions
|
||||
|
||||
@@ -320,8 +291,7 @@ class TDMPCPolicy(
|
||||
# of the FOWM paper.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
regularization = -(
|
||||
self.config.uncertainty_regularizer_coeff
|
||||
* self.model.Qs(z, actions[t]).std(0)
|
||||
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
|
||||
)
|
||||
else:
|
||||
regularization = 0
|
||||
@@ -341,22 +311,15 @@ class TDMPCPolicy(
|
||||
if self.config.q_ensemble_size > 2:
|
||||
G += (
|
||||
running_discount
|
||||
* torch.min(
|
||||
terminal_values[
|
||||
torch.randint(0, self.config.q_ensemble_size, size=(2,))
|
||||
],
|
||||
dim=0,
|
||||
)[0]
|
||||
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||
0
|
||||
]
|
||||
)
|
||||
else:
|
||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||
# Finally, also regularize the terminal value.
|
||||
if self.config.uncertainty_regularizer_coeff > 0:
|
||||
G -= (
|
||||
running_discount
|
||||
* self.config.uncertainty_regularizer_coeff
|
||||
* terminal_values.std(0)
|
||||
)
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
@@ -368,9 +331,7 @@ class TDMPCPolicy(
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
@@ -388,10 +349,7 @@ class TDMPCPolicy(
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(
|
||||
random_shifts_aug,
|
||||
max_random_shift_ratio=self.config.max_random_shift_ratio,
|
||||
),
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
@@ -409,20 +367,14 @@ class TDMPCPolicy(
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(
|
||||
horizon + 1, batch_size, self.config.latent_dim, device=device
|
||||
)
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
|
||||
z_preds[t], action[t]
|
||||
)
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
|
||||
# Compute Q and V value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(
|
||||
z_preds[:-1], action
|
||||
) # (ensemble, horizon, batch)
|
||||
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
|
||||
v_preds = self.model.V(z_preds[:-1])
|
||||
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
|
||||
|
||||
@@ -436,14 +388,10 @@ class TDMPCPolicy(
|
||||
# actions (not actions estimated by π).
|
||||
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
|
||||
# and the FOWM paper.
|
||||
q_targets = reward + self.config.discount * self.model.V(
|
||||
self.model.encode(next_observations)
|
||||
)
|
||||
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
|
||||
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
|
||||
# are using them to compute loss for V.
|
||||
v_targets = self.model_target.Qs(
|
||||
z_preds[:-1].detach(), action, return_min=True
|
||||
)
|
||||
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
@@ -486,9 +434,7 @@ class TDMPCPolicy(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(
|
||||
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
|
||||
),
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
@@ -526,14 +472,12 @@ class TDMPCPolicy(
|
||||
z_preds = z_preds.detach()
|
||||
# Use stopgrad for the advantage calculation.
|
||||
with torch.no_grad():
|
||||
advantage = self.model_target.Qs(
|
||||
z_preds[:-1], action, return_min=True
|
||||
) - self.model.V(z_preds[:-1])
|
||||
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
|
||||
z_preds[:-1]
|
||||
)
|
||||
info["advantage"] = advantage[0]
|
||||
# (t, b)
|
||||
exp_advantage = torch.clamp(
|
||||
torch.exp(advantage * self.config.advantage_scaling), max=100.0
|
||||
)
|
||||
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
@@ -588,9 +532,7 @@ class TDMPCPolicy(
|
||||
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
|
||||
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
|
||||
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
|
||||
update_ema_parameters(
|
||||
self.model_target, self.model, self.config.target_model_momentum
|
||||
)
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
|
||||
class TDMPCTOLD(nn.Module):
|
||||
@@ -601,9 +543,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self.config = config
|
||||
self._encoder = TDMPCObservationEncoder(config)
|
||||
self._dynamics = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
|
||||
),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -614,9 +554,7 @@ class TDMPCTOLD(nn.Module):
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
self._reward = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
|
||||
),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Mish(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -636,10 +574,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(
|
||||
config.latent_dim + config.output_shapes["action"][0],
|
||||
config.mlp_dim,
|
||||
),
|
||||
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -684,9 +619,7 @@ class TDMPCTOLD(nn.Module):
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
nn.init.zeros_(
|
||||
m[-1].bias
|
||||
) # this has already been done, but keep this line here for good measure
|
||||
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
@@ -784,32 +717,14 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0],
|
||||
config.image_encoder_hidden_dim,
|
||||
7,
|
||||
stride=2,
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
5,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
@@ -825,10 +740,7 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.state"][0],
|
||||
config.state_encoder_hidden_dim,
|
||||
),
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -837,8 +749,7 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0],
|
||||
config.state_encoder_hidden_dim,
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
@@ -855,15 +766,9 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers, obs_dict["observation.image"]
|
||||
)
|
||||
)
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
@@ -906,17 +811,12 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False),
|
||||
module.named_parameters(recurse=False),
|
||||
strict=True,
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if (
|
||||
isinstance(module, nn.modules.batchnorm._BatchNorm)
|
||||
or not p.requires_grad
|
||||
):
|
||||
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
|
||||
with torch.no_grad():
|
||||
@@ -924,9 +824,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(
|
||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
||||
) -> Tensor:
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
|
||||
193
lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
Normal file
193
lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class TDMPC2Config:
|
||||
"""Configuration class for TDMPC2Policy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
|
||||
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
|
||||
(π), Q ensemble, and V.
|
||||
discount: Discount factor (γ) to use for the reinforcement learning formalism.
|
||||
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
|
||||
(π) for each step.
|
||||
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
|
||||
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
|
||||
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
|
||||
be non-zero.
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
reward_coeff: Loss weighting coefficient for the reward regression loss.
|
||||
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
|
||||
value (V) expectile regression loss.
|
||||
consistency_coeff: Loss weighting coefficient for the consistency loss.
|
||||
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
|
||||
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
|
||||
current time step.
|
||||
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
|
||||
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
|
||||
model being trained.
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: int = 1
|
||||
horizon: int = 3
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: int = 32
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 512
|
||||
q_ensemble_size: int = 5
|
||||
num_enc_layers: int = 2
|
||||
mlp_dim: int = 512
|
||||
# Reinforcement learning.
|
||||
discount: float = 0.9
|
||||
simnorm_dim: int = 8
|
||||
dropout: float = 0.01
|
||||
|
||||
# actor
|
||||
log_std_min: float = -10
|
||||
log_std_max: float = 2
|
||||
|
||||
# critic
|
||||
num_bins: int = 101
|
||||
vmin: int = -10
|
||||
vmax: int = +10
|
||||
|
||||
# Inference.
|
||||
use_mpc: bool = True
|
||||
cem_iterations: int = 6
|
||||
max_std: float = 2.0
|
||||
min_std: float = 0.05
|
||||
n_gaussian_samples: int = 512
|
||||
n_pi_samples: int = 24
|
||||
n_elites: int = 64
|
||||
elite_weighting_temperature: float = 0.5
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: float = 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: float = 0.1
|
||||
value_coeff: float = 0.1
|
||||
consistency_coeff: float = 20.0
|
||||
entropy_coef: float = 1e-4
|
||||
temporal_decay_coeff: float = 0.5
|
||||
# Target model. NOTE (michel_aractingi) this is equivelant to
|
||||
# 1 - target_model_momentum of our TD-MPC1 implementation because
|
||||
# of the use of `torch.lerp`
|
||||
target_model_momentum: float = 0.01
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
834
lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
Normal file
834
lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
Normal file
@@ -0,0 +1,834 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
|
||||
|
||||
We refer to the main paper and codebase:
|
||||
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
|
||||
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
|
||||
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
|
||||
NormedLinear,
|
||||
SimNorm,
|
||||
gaussian_logprob,
|
||||
soft_cross_entropy,
|
||||
squash,
|
||||
two_hot_inv,
|
||||
)
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPC2Policy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc2"],
|
||||
):
|
||||
"""Implementation of TD-MPC2 learning + inference."""
|
||||
|
||||
name = "tdmpc2"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPC2Config | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = TDMPC2Config()
|
||||
self.config = config
|
||||
self.model = TDMPC2WorldModel(config)
|
||||
# TODO (michel-aractingi) temp fix for gpu
|
||||
self.model = self.model.to("cuda:0")
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.scale = RunningScale(self.config.target_model_momentum)
|
||||
self.discount = (
|
||||
self.config.discount
|
||||
) # TODO (michel-aractingi) downscale discount according to episode length
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
if self.config.use_mpc: # noqa: SIM108
|
||||
actions = self.plan(z) # (horizon, batch, action_dim)
|
||||
else:
|
||||
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
|
||||
# sequence dimension like in the MPC branch.
|
||||
actions = self.model.pi(z)[0].unsqueeze(0)
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.n_action_repeats > 1:
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(actions[0])
|
||||
else:
|
||||
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
|
||||
self._queues["action"].extend(actions[: self.config.n_action_steps])
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan sequence of actions using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch_size = z.shape[0]
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
pi_actions[t] = self.model.pi(_z)[0]
|
||||
_z = self.model.latent_dynamics(_z, pi_actions[t])
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
std = self.config.max_std * torch.ones_like(mean)
|
||||
|
||||
for _ in range(self.config.cem_iterations):
|
||||
# Randomly sample action trajectories for the gaussian distribution.
|
||||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0).squeeze()
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
|
||||
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
|
||||
)
|
||||
std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
|
||||
).clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
# Keep track of the mean for warm-starting subsequent steps.
|
||||
self._prev_mean = mean
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim) tensor of initial latent states.
|
||||
actions: (horizon, batch, action_dim) tensor of action trajectories.
|
||||
Returns:
|
||||
(batch,) tensor of values.
|
||||
"""
|
||||
# Initialize return and running discount factor.
|
||||
G, running_discount = 0, 1
|
||||
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
|
||||
# model. Keep track of return.
|
||||
for t in range(actions.shape[0]):
|
||||
# Estimate the next state (latent) and reward.
|
||||
z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
|
||||
# Update the return and running discount.
|
||||
G += running_discount * reward
|
||||
running_discount *= self.config.discount
|
||||
|
||||
# next_action = self.model.pi(z)[0] # (batch, action_dim)
|
||||
# terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
|
||||
|
||||
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
reward = batch["next.reward"] # (t, b)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
# Get the current observation for predicting trajectories, and all future observations for use in
|
||||
# the latent consistency loss and TD loss.
|
||||
current_observation, next_observations = {}, {}
|
||||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty(horizon, batch_size, self.config.num_bins, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
|
||||
# Compute Q value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(
|
||||
z_preds[:-1], action, return_type="all"
|
||||
) # (ensemble, horizon, batch)
|
||||
info.update({"Q": q_preds_ensemble.mean().item()})
|
||||
|
||||
# Compute various targets with stopgrad.
|
||||
with torch.no_grad():
|
||||
# Latent state consistency targets for consistency loss.
|
||||
z_targets = self.model.encode(next_observations)
|
||||
|
||||
# Compute the TD-target from a reward and the next observation
|
||||
pi = self.model.pi(z_targets)[0]
|
||||
td_targets = (
|
||||
reward
|
||||
+ self.config.discount
|
||||
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
|
||||
)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
|
||||
temporal_loss_coeffs = torch.pow(
|
||||
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
|
||||
).unsqueeze(-1)
|
||||
|
||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||
# predicted from the (target model's) observation encoder.
|
||||
consistency_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||
# rewards.
|
||||
reward_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* soft_cross_entropy(reward_preds, reward, self.config).mean(1)
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
ce_value_loss = 0.0
|
||||
for i in range(self.config.q_ensemble_size):
|
||||
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
|
||||
|
||||
q_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* ce_value_loss
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
z_preds = z_preds.detach()
|
||||
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
|
||||
|
||||
with torch.no_grad():
|
||||
# avoid unnessecary computation of the gradients during policy optimization
|
||||
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
|
||||
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
|
||||
self.scale.update(qs[0])
|
||||
qs = self.scale(qs)
|
||||
|
||||
pi_loss = (
|
||||
(self.config.entropy_coef * log_pis - qs).mean(dim=2)
|
||||
* temporal_loss_coeffs
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
loss = (
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
+ self.config.reward_coeff * reward_loss
|
||||
+ self.config.value_coeff * q_value_loss
|
||||
+ pi_loss
|
||||
)
|
||||
|
||||
info.update(
|
||||
{
|
||||
"consistency_loss": consistency_loss.item(),
|
||||
"reward_loss": reward_loss.item(),
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
"pi_scale": float(self.scale.value),
|
||||
}
|
||||
)
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's using polyak averaging."""
|
||||
self.model.update_target_Q()
|
||||
|
||||
|
||||
class TDMPC2WorldModel(nn.Module):
|
||||
"""Latent dynamics model used in TD-MPC2."""
|
||||
|
||||
def __init__(self, config: TDMPC2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self._encoder = TDMPC2ObservationEncoder(config)
|
||||
|
||||
# Define latent dynamics head
|
||||
self._dynamics = nn.Sequential(
|
||||
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
|
||||
)
|
||||
|
||||
# Define reward head
|
||||
self._reward = nn.Sequential(
|
||||
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
|
||||
)
|
||||
|
||||
# Define policy head
|
||||
self._pi = nn.Sequential(
|
||||
NormedLinear(config.latent_dim, config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
|
||||
)
|
||||
|
||||
# Define ensemble of Q functions
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
NormedLinear(
|
||||
config.latent_dim + config.output_shapes["action"][0],
|
||||
config.mlp_dim,
|
||||
dropout=config.dropout,
|
||||
),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
|
||||
)
|
||||
for _ in range(config.q_ensemble_size)
|
||||
]
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
||||
|
||||
self.log_std_min = torch.tensor(config.log_std_min)
|
||||
self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min
|
||||
|
||||
self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
|
||||
self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize model weights.
|
||||
Custom weight initializations proposed in TD-MPC2.
|
||||
|
||||
"""
|
||||
|
||||
def _apply_fn(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ParameterList):
|
||||
for i, p in enumerate(m):
|
||||
if p.dim() == 3: # Linear
|
||||
nn.init.trunc_normal_(p, std=0.02) # Weight
|
||||
nn.init.constant_(m[i + 1], 0) # Bias
|
||||
|
||||
self.apply(_apply_fn)
|
||||
|
||||
# initialize parameters of the
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Overriding `to` method to also move additional tensors to device.
|
||||
"""
|
||||
super().to(*args, **kwargs)
|
||||
self.log_std_min = self.log_std_min.to(*args, **kwargs)
|
||||
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
|
||||
self.bins = self.bins.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def train(self, mode):
|
||||
super().train(mode)
|
||||
self._target_Qs.train(False)
|
||||
return self
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
return self._encoder(obs)
|
||||
|
||||
def latent_dynamics_and_reward(
|
||||
self, z: Tensor, a: Tensor, discretize_reward: bool = False
|
||||
) -> tuple[Tensor, Tensor, bool]:
|
||||
"""Predict the next state's latent representation and the reward given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- (*, latent_dim) tensor for the next state's latent representation.
|
||||
- (*,) tensor for the estimated reward.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
reward = self._reward(x).squeeze(-1)
|
||||
if discretize_reward:
|
||||
reward = two_hot_inv(reward, self.bins)
|
||||
return self._dynamics(x), reward
|
||||
|
||||
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
|
||||
"""Predict the next state's latent representation given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
(*, latent_dim) tensor for the next state's latent representation.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x)
|
||||
|
||||
def pi(self, z: Tensor) -> Tensor:
|
||||
"""Samples an action from the learned policy.
|
||||
|
||||
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
|
||||
generating rollouts for online training.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
std: The standard deviation of the injected noise.
|
||||
Returns:
|
||||
(*, action_dim) tensor for the sampled action.
|
||||
"""
|
||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
||||
log_std = self.log_std_min + 0.5 * self.log_std_dif * (torch.tanh(log_std) + 1)
|
||||
eps = torch.randn_like(mu)
|
||||
|
||||
log_pi = gaussian_logprob(eps, log_std)
|
||||
pi = mu + eps * log_std.exp()
|
||||
mu, pi, log_pi = squash(mu, pi, log_pi)
|
||||
|
||||
return pi, mu, log_pi, log_std
|
||||
|
||||
def Qs(self, z: Tensor, a: Tensor, return_type: str = "min", target=False) -> Tensor: # noqa: N802
|
||||
"""Predict state-action value for all of the learned Q functions.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
return_type: either 'min' or 'all' otherwise the average is returned
|
||||
Returns:
|
||||
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble or the average or min
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
|
||||
if target:
|
||||
out = torch.stack([q(x).squeeze(-1) for q in self._target_Qs], dim=0)
|
||||
else:
|
||||
out = torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
|
||||
|
||||
if return_type == "all":
|
||||
return out
|
||||
|
||||
Q1, Q2 = out[np.random.choice(len(self._Qs), size=2, replace=False)]
|
||||
Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
def update_target_Q(self):
|
||||
"""
|
||||
Soft-update target Q-networks using Polyak averaging.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
|
||||
p_target.data.lerp_(p.data, self.config.target_model_momentum)
|
||||
|
||||
|
||||
class TDMPC2ObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: TDMPC2Config):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
|
||||
channel dimension. Re-implement this capability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Define the observation encoder whether its pixels or states
|
||||
encoder_dict = {}
|
||||
for obs_key in config.input_shapes:
|
||||
if "observation.image" in config.input_shapes:
|
||||
encoder_module = nn.Sequential(
|
||||
nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
|
||||
with torch.inference_mode():
|
||||
out_shape = encoder_module(dummy_batch).shape[1:]
|
||||
encoder_module.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
|
||||
)
|
||||
)
|
||||
|
||||
elif (
|
||||
"observation.state" in config.input_shapes
|
||||
or "observation.environment_state" in config.input_shapes
|
||||
):
|
||||
encoder_module = nn.ModuleList()
|
||||
encoder_module.append(
|
||||
NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
|
||||
)
|
||||
assert config.num_enc_layers > 0
|
||||
for _ in range(config.num_enc_layers - 1):
|
||||
encoder_module.append(
|
||||
NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
|
||||
)
|
||||
encoder_module.append(
|
||||
NormedLinear(
|
||||
config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
|
||||
)
|
||||
)
|
||||
encoder_module = nn.Sequential(*encoder_module)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
|
||||
|
||||
encoder_dict[obs_key.replace(".", "")] = encoder_module
|
||||
|
||||
self.encoder = nn.ModuleDict(encoder_dict)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
for obs_key in self.config.input_shapes:
|
||||
if "observation.image" in obs_key:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
|
||||
)
|
||||
else:
|
||||
feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
|
||||
"""Randomly shifts images horizontally and vertically.
|
||||
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
b, _, h, w = x.size()
|
||||
assert h == w, "non-square images not handled yet"
|
||||
pad = int(round(max_random_shift_ratio * h))
|
||||
x = F.pad(x, tuple([pad] * 4), "replicate")
|
||||
eps = 1.0 / (h + 2 * pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps,
|
||||
1.0 - eps,
|
||||
h + 2 * pad,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)[:h]
|
||||
arange = einops.repeat(arange, "w -> h w 1", h=h)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
|
||||
# A random shift in units of pixels and within the boundaries of the padding.
|
||||
shift = torch.randint(
|
||||
0,
|
||||
2 * pad + 1,
|
||||
size=(b, 1, 1, 2),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * pad)
|
||||
grid = base_grid + shift
|
||||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
|
||||
different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
|
||||
class RunningScale:
|
||||
"""Running trimmed scale estimator."""
|
||||
|
||||
def __init__(self, tau):
|
||||
self.tau = tau
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device("cuda"))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device("cuda"))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self._value, percentiles=self._percentiles)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._value.data.copy_(state_dict["value"])
|
||||
self._percentiles.data.copy_(state_dict["percentiles"])
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value.cpu().item()
|
||||
|
||||
def _percentile(self, x):
|
||||
x_dtype, x_shape = x.dtype, x.shape
|
||||
x = x.view(x.shape[0], -1)
|
||||
in_sorted, _ = torch.sort(x, dim=0)
|
||||
positions = self._percentiles * (x.shape[0] - 1) / 100
|
||||
floored = torch.floor(positions)
|
||||
ceiled = floored + 1
|
||||
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
|
||||
weight_ceiled = positions - floored
|
||||
weight_floored = 1.0 - weight_ceiled
|
||||
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
|
||||
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
|
||||
return (d0 + d1).view(-1, *x_shape[1:]).type(x_dtype)
|
||||
|
||||
def update(self, x):
|
||||
percentiles = self._percentile(x.detach())
|
||||
value = torch.clamp(percentiles[1] - percentiles[0], min=1.0)
|
||||
self._value.data.lerp_(value, self.tau)
|
||||
|
||||
def __call__(self, x, update=False):
|
||||
if update:
|
||||
self.update(x)
|
||||
return x * (1 / self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return f"RunningScale(S: {self.value})"
|
||||
164
lerobot/common/policies/tdmpc2/tdmpc2_utils.py
Normal file
164
lerobot/common/policies/tdmpc2/tdmpc2_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functorch import combine_state_for_ensemble
|
||||
|
||||
|
||||
class Ensemble(nn.Module):
|
||||
"""
|
||||
Vectorized ensemble of modules.
|
||||
"""
|
||||
|
||||
def __init__(self, modules, **kwargs):
|
||||
super().__init__()
|
||||
modules = nn.ModuleList(modules)
|
||||
fn, params, _ = combine_state_for_ensemble(modules)
|
||||
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
|
||||
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
||||
self._repr = str(modules)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "Vectorized " + self._repr
|
||||
|
||||
|
||||
class SimNorm(nn.Module):
|
||||
"""
|
||||
Simplicial normalization.
|
||||
Adapted from https://arxiv.org/abs/2204.00616.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
shp = x.shape
|
||||
x = x.view(*shp[:-1], -1, self.dim)
|
||||
x = F.softmax(x, dim=-1)
|
||||
return x.view(*shp)
|
||||
|
||||
def __repr__(self):
|
||||
return f"SimNorm(dim={self.dim})"
|
||||
|
||||
|
||||
class NormedLinear(nn.Linear):
|
||||
"""
|
||||
Linear layer with LayerNorm, activation, and optionally dropout.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ln = nn.LayerNorm(self.out_features)
|
||||
self.act = act
|
||||
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
|
||||
|
||||
def forward(self, x):
|
||||
x = super().forward(x)
|
||||
if self.dropout:
|
||||
x = self.dropout(x)
|
||||
return self.act(self.ln(x))
|
||||
|
||||
def __repr__(self):
|
||||
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
|
||||
return (
|
||||
f"NormedLinear(in_features={self.in_features}, "
|
||||
f"out_features={self.out_features}, "
|
||||
f"bias={self.bias is not None}{repr_dropout}, "
|
||||
f"act={self.act.__class__.__name__})"
|
||||
)
|
||||
|
||||
|
||||
def soft_cross_entropy(pred, target, cfg):
|
||||
"""Computes the cross entropy loss between predictions and soft targets."""
|
||||
pred = F.log_softmax(pred, dim=-1)
|
||||
target = two_hot(target, cfg)
|
||||
return -(target * pred).sum(-1, keepdim=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def log_std(x, low, dif):
|
||||
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_residual(eps, log_std):
|
||||
return -0.5 * eps.pow(2) - log_std
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_logprob(residual):
|
||||
return residual - 0.5 * torch.log(2 * torch.pi)
|
||||
|
||||
|
||||
def gaussian_logprob(eps, log_std, size=None):
|
||||
"""Compute Gaussian log probability."""
|
||||
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
||||
if size is None:
|
||||
size = eps.size(-1)
|
||||
return _gaussian_logprob(residual) * size
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _squash(pi):
|
||||
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
|
||||
"""Apply squashing function."""
|
||||
mu = torch.tanh(mu)
|
||||
pi = torch.tanh(pi)
|
||||
log_pi -= _squash(pi).sum(-1, keepdim=True)
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def symlog(x):
|
||||
"""
|
||||
Symmetric logarithmic function.
|
||||
Adapted from https://github.com/danijar/dreamerv3.
|
||||
"""
|
||||
return torch.sign(x) * torch.log(1 + torch.abs(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def symexp(x):
|
||||
"""
|
||||
Symmetric exponential function.
|
||||
Adapted from https://github.com/danijar/dreamerv3.
|
||||
"""
|
||||
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
|
||||
|
||||
|
||||
def two_hot(x, cfg):
|
||||
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
|
||||
|
||||
# x shape [horizon, num_features]
|
||||
if cfg.num_bins == 0:
|
||||
return x
|
||||
elif cfg.num_bins == 1:
|
||||
return symlog(x)
|
||||
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
|
||||
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() # shape [num_features]
|
||||
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) # shape [num_features , 1]
|
||||
soft_two_hot = torch.zeros(
|
||||
*x.shape, cfg.num_bins, device=x.device
|
||||
) # shape [horizon, num_features, num_bins]
|
||||
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
|
||||
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
|
||||
return soft_two_hot
|
||||
|
||||
|
||||
def two_hot_inv(x, bins):
|
||||
"""Converts a batch of soft two-hot encoded vectors to scalars."""
|
||||
num_bins = bins.shape[0]
|
||||
if num_bins == 0:
|
||||
return x
|
||||
elif num_bins == 1:
|
||||
return symexp(x)
|
||||
|
||||
x = F.softmax(x, dim=-1)
|
||||
x = torch.sum(x * bins, dim=-1, keepdim=True)
|
||||
return symexp(x)
|
||||
@@ -109,9 +109,7 @@ class VQBeTConfig:
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -79,9 +79,7 @@ class VQBeTPolicy(
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -106,12 +104,8 @@ class VQBeTPolicy(
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -122,14 +116,8 @@ class VQBeTPolicy(
|
||||
)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {
|
||||
k: torch.stack(list(self._queues[k]), dim=1)
|
||||
for k in batch
|
||||
if k in self._queues
|
||||
}
|
||||
actions = self.vqbet(batch, rollout=True)[
|
||||
:, : self.config.action_chunk_size
|
||||
]
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
|
||||
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
@@ -142,12 +130,8 @@ class VQBeTPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.expected_image_keys], dim=-4
|
||||
)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
@@ -155,9 +139,7 @@ class VQBeTPolicy(
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(
|
||||
self.config.n_vqvae_training_steps, batch["action"]
|
||||
)
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
@@ -214,9 +196,7 @@ class SpatialSoftmax(nn.Module):
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(
|
||||
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
|
||||
)
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
@@ -308,17 +288,14 @@ class VQBeTModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len(
|
||||
[k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.input_shapes["observation.state"][0],
|
||||
hidden_channels=[self.config.gpt_input_dim],
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -333,12 +310,7 @@ class VQBeTModel(nn.Module):
|
||||
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
|
||||
self.register_buffer(
|
||||
"select_target_actions_indices",
|
||||
torch.row_stack(
|
||||
[
|
||||
torch.arange(i, i + self.config.action_chunk_size)
|
||||
for i in range(num_tokens)
|
||||
]
|
||||
),
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
@@ -353,11 +325,7 @@ class VQBeTModel(nn.Module):
|
||||
)
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features,
|
||||
"(b s n) ... -> b s n ...",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
n=self.num_images,
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
)
|
||||
|
||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||
@@ -369,19 +337,13 @@ class VQBeTModel(nn.Module):
|
||||
input_tokens.append(
|
||||
self.state_projector(batch["observation.state"])
|
||||
) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(
|
||||
einops.repeat(
|
||||
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
|
||||
)
|
||||
)
|
||||
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack(input_tokens, dim=2)
|
||||
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
|
||||
|
||||
len_additional_action_token = self.config.n_action_pred_token - 1
|
||||
future_action_tokens = self.action_token.repeat(
|
||||
batch_size, len_additional_action_token, 1
|
||||
)
|
||||
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
# add additional action query tokens for predicting future action chunks
|
||||
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
|
||||
@@ -390,9 +352,9 @@ class VQBeTModel(nn.Module):
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_shapes) is the number of different observation modes.
|
||||
# this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (
|
||||
len(self.config.input_shapes) + 1
|
||||
) + len(self.config.input_shapes)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||
@@ -400,11 +362,7 @@ class VQBeTModel(nn.Module):
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
[
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:],
|
||||
],
|
||||
dim=1,
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
)
|
||||
else:
|
||||
features = features[:, historical_act_pred_index]
|
||||
@@ -412,15 +370,13 @@ class VQBeTModel(nn.Module):
|
||||
action_head_output = self.action_head(features)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
if rollout:
|
||||
return action_head_output["predicted_action"][
|
||||
:, n_obs_steps - 1, :
|
||||
].reshape(batch_size, self.config.action_chunk_size, -1)
|
||||
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
|
||||
batch_size, self.config.action_chunk_size, -1
|
||||
)
|
||||
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||
else:
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(
|
||||
action_head_output, output, reduction="mean"
|
||||
)
|
||||
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
|
||||
return action_head_output, loss
|
||||
|
||||
|
||||
@@ -455,9 +411,7 @@ class VQBeTHead(nn.Module):
|
||||
else:
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
|
||||
],
|
||||
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
|
||||
)
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
@@ -484,10 +438,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
loss, metric = self.vqvae_model.vqvae_forward(actions)
|
||||
n_different_codes = sum(
|
||||
[
|
||||
len(torch.unique(metric[2][:, i]))
|
||||
for i in range(self.vqvae_model.vqvae_num_layers)
|
||||
]
|
||||
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
|
||||
)
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
recon_l1_error = metric[0].detach().cpu().item()
|
||||
@@ -534,13 +485,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||
torch.cat(
|
||||
(
|
||||
x,
|
||||
F.one_hot(
|
||||
sampled_primary_centers,
|
||||
num_classes=self.config.vqvae_n_embed,
|
||||
),
|
||||
),
|
||||
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
@@ -548,29 +493,19 @@ class VQBeTHead(nn.Module):
|
||||
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
sampled_secondary_centers = einops.rearrange(
|
||||
torch.multinomial(
|
||||
cbet_secondary_probs.view(-1, choices), num_samples=1
|
||||
),
|
||||
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
sampled_centers = torch.stack(
|
||||
(sampled_primary_centers, sampled_secondary_centers), axis=1
|
||||
)
|
||||
cbet_logits = torch.stack(
|
||||
[cbet_primary_logits, cbet_secondary_logits], dim=1
|
||||
)
|
||||
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
|
||||
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
|
||||
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits,
|
||||
"(NT) (G C) -> (NT) G C",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
)
|
||||
cbet_probs = torch.softmax(
|
||||
cbet_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
@@ -590,17 +525,9 @@ class VQBeTHead(nn.Module):
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = (
|
||||
self.vqvae_model.get_embeddings_from_code(sampled_centers)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = (
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
)
|
||||
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
@@ -649,9 +576,7 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -674,12 +599,8 @@ class VQBeTHead(nn.Module):
|
||||
+ cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
)
|
||||
|
||||
equal_primary_code_rate = torch.sum(
|
||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||
) / (NT)
|
||||
equal_secondary_code_rate = torch.sum(
|
||||
(action_bins[:, 1] == sampled_centers[:, 1]).int()
|
||||
) / (NT)
|
||||
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
|
||||
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
|
||||
|
||||
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
|
||||
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
|
||||
@@ -693,9 +614,7 @@ class VQBeTHead(nn.Module):
|
||||
"classification_loss": cbet_loss.detach().cpu().item(),
|
||||
"offset_loss": offset_loss.detach().cpu().item(),
|
||||
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
|
||||
.cpu()
|
||||
.item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
|
||||
"vq_action_error": vq_action_error.detach().cpu().item(),
|
||||
"offset_action_error": offset_action_error.detach().cpu().item(),
|
||||
"action_error_max": action_error_max.detach().cpu().item(),
|
||||
@@ -724,17 +643,11 @@ class VQBeTOptimizer(torch.optim.Adam):
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
|
||||
)
|
||||
+ list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
|
||||
)
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(
|
||||
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()
|
||||
)
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
@@ -780,11 +693,7 @@ class VQBeTScheduler(nn.Module):
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(
|
||||
0.0,
|
||||
0.5
|
||||
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
@@ -808,9 +717,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(
|
||||
config.crop_shape
|
||||
)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -831,9 +738,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
@@ -841,25 +746,17 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape
|
||||
if config.crop_shape is not None
|
||||
else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(
|
||||
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -886,9 +783,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -901,11 +796,7 @@ def _replace_submodules(
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [
|
||||
k.split(".")
|
||||
for k, m in root_module.named_modules(remove_duplicate=True)
|
||||
if predicate(m)
|
||||
]
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
@@ -920,9 +811,7 @@ def _replace_submodules(
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(
|
||||
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
|
||||
)
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
@@ -955,8 +844,7 @@ class VqVae(nn.Module):
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0]
|
||||
* self.config.action_chunk_size,
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
@@ -984,13 +872,9 @@ class VqVae(nn.Module):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
|
||||
)
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
else:
|
||||
return einops.rearrange(
|
||||
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
|
||||
)
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
|
||||
@@ -123,15 +123,9 @@ class CausalSelfAttention(nn.Module):
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
|
||||
1, 2
|
||||
) # (B, nh, T, hs)
|
||||
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
@@ -139,9 +133,7 @@ class CausalSelfAttention(nn.Module):
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = (
|
||||
y.transpose(1, 2).contiguous().view(B, T, C)
|
||||
) # re-assemble all head outputs side by side
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
||||
|
||||
# output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
@@ -197,16 +189,12 @@ class GPT(nn.Module):
|
||||
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
|
||||
}
|
||||
)
|
||||
self.lm_head = nn.Linear(
|
||||
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
|
||||
)
|
||||
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
|
||||
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
|
||||
self.apply(self._init_weights)
|
||||
for pn, p in self.named_parameters():
|
||||
if pn.endswith("c_proj.weight"):
|
||||
torch.nn.init.normal_(
|
||||
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
|
||||
)
|
||||
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
|
||||
|
||||
# report number of parameters
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
@@ -220,17 +208,11 @@ class GPT(nn.Module):
|
||||
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
|
||||
# positional encodings that are added to the input embeddings
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
|
||||
0
|
||||
) # shape (1, t)
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
|
||||
# forward the GPT model itself
|
||||
tok_emb = self.transformer.wte(
|
||||
input
|
||||
) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(
|
||||
pos
|
||||
) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
@@ -255,9 +237,7 @@ class GPT(nn.Module):
|
||||
# but want to use a smaller block size for some smaller, simpler model
|
||||
assert gpt_block_size <= self.config.gpt_block_size
|
||||
self.config.gpt_block_size = gpt_block_size
|
||||
self.transformer.wpe.weight = nn.Parameter(
|
||||
self.transformer.wpe.weight[:gpt_block_size]
|
||||
)
|
||||
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
|
||||
for block in self.transformer.h:
|
||||
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
|
||||
|
||||
@@ -290,9 +270,7 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert (
|
||||
len(inter_params) == 0
|
||||
), "parameters {} made it into both decay/no_decay sets!".format(
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
)
|
||||
assert (
|
||||
@@ -390,12 +368,8 @@ class ResidualVQ(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
|
||||
self.num_quantizers = num_quantizers
|
||||
|
||||
@@ -403,10 +377,7 @@ class ResidualVQ(nn.Module):
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
VectorQuantize(
|
||||
dim=codebook_dim,
|
||||
codebook_dim=codebook_dim,
|
||||
accept_image_fmap=accept_image_fmap,
|
||||
**kwargs,
|
||||
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
|
||||
)
|
||||
for _ in range(num_quantizers)
|
||||
]
|
||||
@@ -477,9 +448,7 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
return all_codes
|
||||
|
||||
def forward(
|
||||
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
|
||||
):
|
||||
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
|
||||
"""
|
||||
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
|
||||
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
|
||||
@@ -508,17 +477,13 @@ class ResidualVQ(nn.Module):
|
||||
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
|
||||
ce_losses = []
|
||||
|
||||
should_quantize_dropout = (
|
||||
self.training and self.quantize_dropout and not return_loss
|
||||
)
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
|
||||
# sample a layer index at which to dropout further residual quantization
|
||||
# also prepare null indices and loss
|
||||
|
||||
if should_quantize_dropout:
|
||||
rand_quantize_dropout_index = randrange(
|
||||
self.quantize_dropout_cutoff_index, num_quant
|
||||
)
|
||||
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
|
||||
|
||||
if quant_dropout_multiple_of != 1:
|
||||
rand_quantize_dropout_index = (
|
||||
@@ -527,23 +492,14 @@ class ResidualVQ(nn.Module):
|
||||
- 1
|
||||
)
|
||||
|
||||
null_indices_shape = (
|
||||
(x.shape[0], *x.shape[-2:])
|
||||
if self.accept_image_fmap
|
||||
else tuple(x.shape[:2])
|
||||
)
|
||||
null_indices = torch.full(
|
||||
null_indices_shape, -1.0, device=device, dtype=torch.long
|
||||
)
|
||||
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
|
||||
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
|
||||
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
|
||||
|
||||
# go through the layers
|
||||
|
||||
for quantizer_index, layer in enumerate(self.layers):
|
||||
if (
|
||||
should_quantize_dropout
|
||||
and quantizer_index > rand_quantize_dropout_index
|
||||
):
|
||||
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
|
||||
all_indices.append(null_indices)
|
||||
all_losses.append(null_loss)
|
||||
continue
|
||||
@@ -583,9 +539,7 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
# stack all losses and indices
|
||||
|
||||
all_losses, all_indices = map(
|
||||
partial(torch.stack, dim=-1), (all_losses, all_indices)
|
||||
)
|
||||
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
|
||||
|
||||
ret = (quantized_out, all_indices, all_losses)
|
||||
|
||||
@@ -645,12 +599,8 @@ class VectorQuantize(nn.Module):
|
||||
codebook_input_dim = codebook_dim * heads
|
||||
|
||||
requires_projection = codebook_input_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
|
||||
|
||||
self.eps = eps
|
||||
self.commitment_weight = commitment_weight
|
||||
@@ -664,14 +614,10 @@ class VectorQuantize(nn.Module):
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
assert not (
|
||||
ema_update and learnable_codebook
|
||||
), "learnable codebook not compatible with EMA update"
|
||||
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
|
||||
|
||||
assert 0 <= sync_update_v <= 1.0
|
||||
assert not (
|
||||
sync_update_v > 0.0 and not learnable_codebook
|
||||
), "learnable codebook must be turned on"
|
||||
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
|
||||
|
||||
self.sync_update_v = sync_update_v
|
||||
|
||||
@@ -683,9 +629,7 @@ class VectorQuantize(nn.Module):
|
||||
)
|
||||
|
||||
if sync_codebook is None:
|
||||
sync_codebook = (
|
||||
distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
)
|
||||
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
|
||||
|
||||
codebook_kwargs = {
|
||||
"dim": codebook_dim,
|
||||
@@ -850,17 +794,11 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
# quantize again
|
||||
|
||||
quantize, embed_ind, distances = self._codebook(
|
||||
x, **codebook_forward_kwargs
|
||||
)
|
||||
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
|
||||
|
||||
if self.training:
|
||||
# determine code to use for commitment loss
|
||||
maybe_detach = (
|
||||
torch.detach
|
||||
if not self.learnable_codebook or freeze_codebook
|
||||
else identity
|
||||
)
|
||||
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
|
||||
|
||||
commit_quantize = maybe_detach(quantize)
|
||||
|
||||
@@ -870,9 +808,7 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
if self.sync_update_v > 0.0:
|
||||
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
|
||||
quantize = quantize + self.sync_update_v * (
|
||||
quantize - quantize.detach()
|
||||
)
|
||||
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
|
||||
|
||||
# function for calculating cross entropy loss to distance matrix
|
||||
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
|
||||
@@ -905,9 +841,7 @@ class VectorQuantize(nn.Module):
|
||||
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
|
||||
|
||||
if self.accept_image_fmap:
|
||||
embed_ind = rearrange(
|
||||
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
|
||||
)
|
||||
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
|
||||
|
||||
if only_one:
|
||||
embed_ind = rearrange(embed_ind, "b 1 -> b")
|
||||
@@ -961,12 +895,8 @@ class VectorQuantize(nn.Module):
|
||||
|
||||
num_codes = codebook.shape[-2]
|
||||
|
||||
if (
|
||||
self.orthogonal_reg_max_codes is not None
|
||||
) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[
|
||||
: self.orthogonal_reg_max_codes
|
||||
]
|
||||
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
|
||||
codebook = codebook[:, rand_ids]
|
||||
|
||||
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
|
||||
@@ -998,9 +928,7 @@ class VectorQuantize(nn.Module):
|
||||
# if masking, only return quantized for where mask has True
|
||||
|
||||
if mask is not None:
|
||||
quantize = torch.where(
|
||||
rearrange(mask, "... -> ... 1"), quantize, orig_input
|
||||
)
|
||||
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
|
||||
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
@@ -1110,9 +1038,7 @@ def sample_vectors(samples, num):
|
||||
|
||||
|
||||
def batched_sample_vectors(samples, num):
|
||||
return torch.stack(
|
||||
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
|
||||
)
|
||||
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
|
||||
|
||||
|
||||
def pad_shape(shape, size, dim=0):
|
||||
@@ -1163,9 +1089,7 @@ def sample_vectors_distributed(local_samples, num):
|
||||
all_num_samples = all_gather_sizes(local_samples, dim=0)
|
||||
|
||||
if rank == 0:
|
||||
samples_per_rank = sample_multinomial(
|
||||
num, all_num_samples / all_num_samples.sum()
|
||||
)
|
||||
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
|
||||
else:
|
||||
samples_per_rank = torch.empty_like(all_num_samples)
|
||||
|
||||
@@ -1278,9 +1202,7 @@ class EuclideanCodebook(nn.Module):
|
||||
self.eps = eps
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.reset_cluster_size = (
|
||||
reset_cluster_size
|
||||
if (reset_cluster_size is not None)
|
||||
else threshold_ema_dead_code
|
||||
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
|
||||
)
|
||||
|
||||
assert callable(gumbel_sample)
|
||||
@@ -1291,14 +1213,8 @@ class EuclideanCodebook(nn.Module):
|
||||
use_ddp and num_codebooks > 1 and kmeans_init
|
||||
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
|
||||
|
||||
self.sample_fn = (
|
||||
sample_vectors_distributed
|
||||
if use_ddp and sync_kmeans
|
||||
else batched_sample_vectors
|
||||
)
|
||||
self.kmeans_all_reduce_fn = (
|
||||
distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
)
|
||||
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
|
||||
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
|
||||
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
|
||||
|
||||
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
|
||||
@@ -1437,9 +1353,7 @@ class EuclideanCodebook(nn.Module):
|
||||
distributed.all_reduce(variance_numer)
|
||||
batch_variance = variance_numer / num_vectors
|
||||
|
||||
self.update_with_decay(
|
||||
"batch_variance", batch_variance, self.affine_param_batch_decay
|
||||
)
|
||||
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
|
||||
|
||||
def replace(self, batch_samples, batch_mask):
|
||||
for ind, (samples, mask) in enumerate(
|
||||
@@ -1448,9 +1362,7 @@ class EuclideanCodebook(nn.Module):
|
||||
if not torch.any(mask):
|
||||
continue
|
||||
|
||||
sampled = self.sample_fn(
|
||||
rearrange(samples, "... -> 1 ..."), mask.sum().item()
|
||||
)
|
||||
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
|
||||
sampled = rearrange(sampled, "1 ... -> ...")
|
||||
|
||||
self.embed.data[ind][mask] = sampled
|
||||
@@ -1474,9 +1386,7 @@ class EuclideanCodebook(nn.Module):
|
||||
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
|
||||
needs_codebook_dim = x.ndim < 4
|
||||
sample_codebook_temp = (
|
||||
sample_codebook_temp
|
||||
if (sample_codebook_temp is not None)
|
||||
else self.sample_codebook_temp
|
||||
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
|
||||
)
|
||||
|
||||
x = x.float()
|
||||
@@ -1504,9 +1414,7 @@ class EuclideanCodebook(nn.Module):
|
||||
if self.affine_param:
|
||||
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
|
||||
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
|
||||
embed = (embed - self.codebook_mean) * (
|
||||
batch_std / codebook_std
|
||||
) + self.batch_mean
|
||||
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
|
||||
|
||||
dist = -cdist(flatten, embed)
|
||||
|
||||
@@ -1524,9 +1432,7 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
if self.training and self.ema_update and not freeze_codebook:
|
||||
if self.affine_param:
|
||||
flatten = (flatten - self.batch_mean) * (
|
||||
codebook_std / batch_std
|
||||
) + self.codebook_mean
|
||||
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
|
||||
|
||||
if mask is not None:
|
||||
embed_onehot[~mask] = 0.0
|
||||
@@ -1549,9 +1455,7 @@ class EuclideanCodebook(nn.Module):
|
||||
self.expire_codes_(x)
|
||||
|
||||
if needs_codebook_dim:
|
||||
quantize, embed_ind = tuple(
|
||||
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
|
||||
)
|
||||
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
|
||||
|
||||
dist = unpack_one(dist, ps, "h * d")
|
||||
|
||||
|
||||
@@ -65,9 +65,7 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
||||
img.save(str(path), quality=100)
|
||||
logging.info(f"Saved image: {path}")
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
|
||||
)
|
||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
@@ -96,9 +94,7 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
camera = IntelRealSenseCamera(
|
||||
cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
@@ -144,9 +140,7 @@ def save_images_from_cameras(
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
|
||||
frame_index += 1
|
||||
finally:
|
||||
@@ -174,7 +168,6 @@ class IntelRealSenseCameraConfig:
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
@@ -186,14 +179,8 @@ class IntelRealSenseCameraConfig:
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = (
|
||||
self.fps is not None or self.width is not None or self.height is not None
|
||||
)
|
||||
at_least_one_is_none = (
|
||||
self.fps is None or self.width is None or self.height is None
|
||||
)
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
raise ValueError(
|
||||
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
|
||||
@@ -201,9 +188,7 @@ class IntelRealSenseCameraConfig:
|
||||
)
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
@@ -269,7 +254,6 @@ class IntelRealSenseCamera:
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.use_depth = config.use_depth
|
||||
self.force_hardware_reset = config.force_hardware_reset
|
||||
@@ -298,9 +282,7 @@ class IntelRealSenseCamera:
|
||||
self.rotation = cv2.ROTATE_180
|
||||
|
||||
@classmethod
|
||||
def init_from_name(
|
||||
cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs
|
||||
):
|
||||
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
|
||||
camera_infos = find_cameras()
|
||||
camera_names = [cam["name"] for cam in camera_infos]
|
||||
this_name_count = Counter(camera_names)[name]
|
||||
@@ -310,9 +292,7 @@ class IntelRealSenseCamera:
|
||||
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
|
||||
)
|
||||
|
||||
name_to_serial_dict = {
|
||||
cam["name"]: cam["serial_number"] for cam in camera_infos
|
||||
}
|
||||
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
|
||||
cam_sn = name_to_serial_dict[name]
|
||||
|
||||
if config is None:
|
||||
@@ -339,17 +319,13 @@ class IntelRealSenseCamera:
|
||||
|
||||
if self.fps and self.width and self.height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
|
||||
)
|
||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.width and self.height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
|
||||
)
|
||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
@@ -382,9 +358,7 @@ class IntelRealSenseCamera:
|
||||
actual_height = color_profile.height()
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
@@ -404,9 +378,7 @@ class IntelRealSenseCamera:
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def read(
|
||||
self, temporary_color: str | None = None
|
||||
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
|
||||
of type `np.uint8`, contrarily to the pytorch format which is float channel first.
|
||||
|
||||
@@ -433,15 +405,11 @@ class IntelRealSenseCamera:
|
||||
color_frame = frame.get_color_frame()
|
||||
|
||||
if not color_frame:
|
||||
raise OSError(
|
||||
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color is None else temporary_color
|
||||
)
|
||||
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
|
||||
@@ -469,9 +437,7 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(
|
||||
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
|
||||
)
|
||||
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
|
||||
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
@@ -513,9 +479,7 @@ class IntelRealSenseCamera:
|
||||
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (
|
||||
self.thread.ident is None or not self.thread.is_alive()
|
||||
):
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
raise Exception(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
|
||||
@@ -31,14 +31,10 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(
|
||||
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
|
||||
) -> list[dict]:
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print(
|
||||
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
|
||||
)
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
|
||||
ports = _find_cameras(possible_ports, mock=mock)
|
||||
for port in ports:
|
||||
@@ -169,9 +165,7 @@ def save_images_from_cameras(
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
print(
|
||||
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
|
||||
)
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
@@ -198,7 +192,6 @@ class OpenCVCameraConfig:
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
@@ -208,12 +201,8 @@ class OpenCVCameraConfig:
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(
|
||||
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
|
||||
)
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
@@ -255,12 +244,7 @@ class OpenCVCamera:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
camera_index: int | str,
|
||||
config: OpenCVCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
|
||||
@@ -274,21 +258,16 @@ class OpenCVCamera:
|
||||
if platform.system() == "Linux":
|
||||
if isinstance(self.camera_index, int):
|
||||
self.port = Path(f"/dev/video{self.camera_index}")
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(
|
||||
self.camera_index
|
||||
):
|
||||
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
|
||||
self.port = Path(self.camera_index)
|
||||
# Retrieve the camera index from a potentially symlinked path
|
||||
self.camera_index = get_camera_index_from_unix_port(self.port)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Please check the provided camera_index: {camera_index}"
|
||||
)
|
||||
raise ValueError(f"Please check the provided camera_index: {camera_index}")
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
|
||||
@@ -315,9 +294,7 @@ class OpenCVCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is already connected."
|
||||
)
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.mock_cv2 as cv2
|
||||
@@ -328,11 +305,7 @@ class OpenCVCamera:
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
camera_idx = (
|
||||
f"/dev/video{self.camera_index}"
|
||||
if platform.system() == "Linux"
|
||||
else self.camera_index
|
||||
)
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx)
|
||||
@@ -372,22 +345,16 @@ class OpenCVCamera:
|
||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
|
||||
if self.fps is not None and not math.isclose(
|
||||
self.fps, actual_fps, rel_tol=1e-3
|
||||
):
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and not math.isclose(
|
||||
self.width, actual_width, rel_tol=1e-3
|
||||
):
|
||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and not math.isclose(
|
||||
self.height, actual_height, rel_tol=1e-3
|
||||
):
|
||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
)
|
||||
@@ -417,9 +384,7 @@ class OpenCVCamera:
|
||||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
requested_color_mode = (
|
||||
self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
)
|
||||
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -11,29 +11,19 @@ from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
def log_control_info(
|
||||
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
):
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
@@ -42,7 +32,7 @@ def log_control_info(
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items, fps
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_val_s
|
||||
if actual_fps < fps - 1:
|
||||
@@ -104,9 +94,7 @@ def predict_action(observation, policy, device, use_amp):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and use_amp
|
||||
else nullcontext(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
@@ -129,22 +117,14 @@ def predict_action(observation, policy, device, use_amp):
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
@@ -162,22 +142,13 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print(
|
||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||
)
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -190,12 +161,8 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(
|
||||
pretrained_policy_path / "config.yaml", policy_overrides
|
||||
)
|
||||
policy = make_policy(
|
||||
hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path
|
||||
)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
@@ -214,7 +181,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
enable_teleoperation,
|
||||
enable_teloperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
fps,
|
||||
@@ -225,7 +192,7 @@ def warmup_record(
|
||||
display_cameras=display_cameras,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teleoperation,
|
||||
teleoperate=enable_teloperation,
|
||||
)
|
||||
|
||||
|
||||
@@ -239,7 +206,6 @@ def record_episode(
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
record_delta_actions,
|
||||
):
|
||||
control_loop(
|
||||
robot=robot,
|
||||
@@ -251,7 +217,6 @@ def record_episode(
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
record_delta_actions=record_delta_actions,
|
||||
teleoperate=policy is None,
|
||||
)
|
||||
|
||||
@@ -262,13 +227,12 @@ def control_loop(
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
dataset=None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
use_amp=None,
|
||||
fps=None,
|
||||
record_delta_actions=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
@@ -283,22 +247,16 @@ def control_loop(
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(
|
||||
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
|
||||
)
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
if record_delta_actions:
|
||||
action["action"] = action["action"] - current_joint_positions
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
@@ -310,23 +268,12 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
frame["next.done"] = (events["next.reward"] == 1) or (
|
||||
events["exit_early"]
|
||||
)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# if frame["next.done"]:
|
||||
# break
|
||||
add_frame(dataset, observation, action)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if fps is not None:
|
||||
@@ -350,8 +297,6 @@ def reset_environment(robot, events, reset_time_s):
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
@@ -364,16 +309,6 @@ def reset_environment(robot, events, reset_time_s):
|
||||
break
|
||||
|
||||
|
||||
def reset_follower_position(robot: Robot, target_position):
|
||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an aribtrary number
|
||||
for pose in trajectory:
|
||||
robot.send_action(pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
@@ -389,47 +324,7 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
if dataset_name.startswith("eval_") == (policy is None):
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
if not dataset_name.startswith("eval_") and policy is not None:
|
||||
raise ValueError(
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
use_videos: bool,
|
||||
extra_features: dict = None,
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, features_from_robot),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(
|
||||
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
||||
)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n"
|
||||
+ "\n".join(mismatches)
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
||||
|
||||
@@ -8,10 +8,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
@@ -146,9 +143,7 @@ NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -383,9 +378,7 @@ class DynamixelMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -401,9 +394,7 @@ class DynamixelMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -424,9 +415,7 @@ class DynamixelMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -439,9 +428,7 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -516,9 +503,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -560,23 +545,15 @@ class DynamixelMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
(resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
||||
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
@@ -584,9 +561,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -602,27 +577,19 @@ class DynamixelMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -632,9 +599,7 @@ class DynamixelMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -673,9 +638,7 @@ class DynamixelMotorsBus:
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -777,9 +740,7 @@ class DynamixelMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -788,9 +749,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
@@ -819,12 +778,7 @@ class DynamixelMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -885,9 +839,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -8,10 +8,7 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
@@ -125,9 +122,7 @@ NUM_READ_RETRY = 20
|
||||
NUM_WRITE_RETRY = 20
|
||||
|
||||
|
||||
def convert_degrees_to_steps(
|
||||
degrees: float | np.ndarray, models: str | list[str]
|
||||
) -> np.ndarray:
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
@@ -363,9 +358,7 @@ class FeetechMotorsBus:
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(
|
||||
self.motor_models, [idx], "ID", num_retry=num_retry
|
||||
)[0]
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
@@ -381,9 +374,7 @@ class FeetechMotorsBus:
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(
|
||||
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
|
||||
)
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
@@ -404,9 +395,7 @@ class FeetechMotorsBus:
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
@@ -419,9 +408,7 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
@@ -495,9 +482,7 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
@@ -536,26 +521,18 @@ class FeetechMotorsBus:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (
|
||||
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
|
||||
calib_val < UPPER_BOUND_DEGREE
|
||||
)
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
|
||||
- values[i]
|
||||
- homing_offset
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
@@ -564,9 +541,7 @@ class FeetechMotorsBus:
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
|
||||
calib_val < UPPER_BOUND_LINEAR
|
||||
)
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
@@ -582,27 +557,19 @@ class FeetechMotorsBus:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(
|
||||
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
|
||||
)
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
in_range_str = (
|
||||
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
)
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
@@ -612,9 +579,7 @@ class FeetechMotorsBus:
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(
|
||||
self, values: np.ndarray | list, motor_names: list[str] | None
|
||||
):
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
@@ -690,9 +655,7 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
|
||||
):
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -797,9 +760,7 @@ class FeetechMotorsBus:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "read", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
@@ -808,9 +769,7 @@ class FeetechMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(
|
||||
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
|
||||
):
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
@@ -839,12 +798,7 @@ class FeetechMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
@@ -905,9 +859,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name(
|
||||
"delta_timestamp_s", "write", data_name, motor_names
|
||||
)
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
|
||||
@@ -10,7 +10,9 @@ from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -21,9 +23,7 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -64,16 +64,12 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -94,15 +90,10 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -111,15 +102,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(
|
||||
rotated_drived_pos, arm.motor_models
|
||||
)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
|
||||
homing_offset = rotated_target_pos - rotated_nearest_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ from lerobot.common.robot_devices.motors.feetech import (
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
@@ -23,9 +25,7 @@ ROTATED_POSITION_DEGREE = 90
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(
|
||||
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
|
||||
)
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
@@ -126,9 +126,7 @@ def apply_offset(calib, offset):
|
||||
return calib
|
||||
|
||||
|
||||
def run_arm_auto_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
if robot_type == "so100":
|
||||
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
|
||||
elif robot_type == "moss":
|
||||
@@ -137,27 +135,18 @@ def run_arm_auto_calibration(
|
||||
raise ValueError(robot_type)
|
||||
|
||||
|
||||
def run_arm_auto_calibration_so100(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
if not (robot_type == "so100" and arm_type == "follower"):
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of so100 arms for now."
|
||||
)
|
||||
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -204,16 +193,11 @@ def run_arm_auto_calibration_so100(
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm,
|
||||
"elbow_flex",
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
|
||||
)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
@@ -241,30 +225,18 @@ def run_arm_auto_calibration_so100(
|
||||
}
|
||||
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
||||
|
||||
arm.write(
|
||||
"Goal_Position",
|
||||
round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
||||
"shoulder_lift",
|
||||
)
|
||||
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
|
||||
time.sleep(2)
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
|
||||
)
|
||||
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
||||
time.sleep(2)
|
||||
arm.write(
|
||||
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
|
||||
)
|
||||
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(
|
||||
arm,
|
||||
"wrist_roll",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
while_move_hook=while_move_hook,
|
||||
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
|
||||
)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||
@@ -274,9 +246,7 @@ def run_arm_auto_calibration_so100(
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
|
||||
)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
@@ -305,27 +275,18 @@ def run_arm_auto_calibration_so100(
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_auto_calibration_moss(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
if not (robot_type == "moss" and arm_type == "follower"):
|
||||
raise NotImplementedError(
|
||||
"Auto calibration only supports the follower of moss arms for now."
|
||||
)
|
||||
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
@@ -409,12 +370,8 @@ def run_arm_auto_calibration_moss(
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write(
|
||||
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
|
||||
)
|
||||
arm.write(
|
||||
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
|
||||
)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
|
||||
time.sleep(2)
|
||||
|
||||
calib_modes = []
|
||||
@@ -441,9 +398,7 @@ def run_arm_auto_calibration_moss(
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_manual_calibration(
|
||||
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
|
||||
):
|
||||
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
@@ -466,16 +421,12 @@ def run_arm_manual_calibration(
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run calibration, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
@@ -495,15 +446,10 @@ def run_arm_manual_calibration(
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print(
|
||||
"See: "
|
||||
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(
|
||||
ROTATED_POSITION_DEGREE, arm.motor_models
|
||||
)
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -515,9 +461,7 @@ def run_arm_manual_calibration(
|
||||
homing_offset = rotated_target_pos - rotated_drived_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print(
|
||||
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
|
||||
)
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
|
||||
@@ -18,16 +18,11 @@ import torch
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_pos: torch.Tensor,
|
||||
present_pos: torch.Tensor,
|
||||
max_relative_target: float | list[float],
|
||||
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
||||
):
|
||||
# Cap relative action target magnitude for safety.
|
||||
diff = goal_pos - present_pos
|
||||
@@ -37,7 +32,7 @@ def ensure_safe_goal_position(
|
||||
safe_goal_pos = present_pos + safe_diff
|
||||
|
||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||
logging.debug(
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f" requested relative goal position target: {diff}\n"
|
||||
f" clamped relative goal position target: {safe_diff}"
|
||||
@@ -72,14 +67,8 @@ class ManipulatorRobotConfig:
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
joint_position_relative_bounds: dict[np.ndarray] | None = None
|
||||
|
||||
def __setattr__(self, prop: str, val):
|
||||
if (
|
||||
prop == "max_relative_target"
|
||||
and val is not None
|
||||
and isinstance(val, Sequence)
|
||||
):
|
||||
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(val):
|
||||
raise ValueError(
|
||||
@@ -89,16 +78,11 @@ class ManipulatorRobotConfig:
|
||||
"Note: This feature does not yet work with robots where different follower arms have "
|
||||
"different numbers of motors."
|
||||
)
|
||||
if prop == "joint_position_relative_bounds" and val is not None:
|
||||
for key in val:
|
||||
val[key] = torch.tensor(val[key])
|
||||
super().__setattr__(prop, val)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
|
||||
raise ValueError(
|
||||
f"Provided robot type ({self.robot_type}) is not supported."
|
||||
)
|
||||
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
|
||||
|
||||
|
||||
class ManipulatorRobot:
|
||||
@@ -242,42 +226,6 @@ class ManipulatorRobot:
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
action_names = self.get_motor_names(self.leader_arms)
|
||||
state_names = self.get_motor_names(self.leader_arms)
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action_names),),
|
||||
"names": action_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state_names),),
|
||||
"names": state_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
@@ -352,9 +300,7 @@ class ManipulatorRobot:
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
|
||||
# Check both arms can be read
|
||||
for name in self.follower_arms:
|
||||
@@ -386,26 +332,18 @@ class ManipulatorRobot:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
|
||||
run_arm_calibration,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||
|
||||
calibration = run_arm_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_manual_calibration(
|
||||
arm, self.robot_type, name, arm_type
|
||||
)
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
print(
|
||||
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
|
||||
)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
@@ -424,17 +362,13 @@ class ManipulatorRobot:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError(
|
||||
"To run set robot preset, the torque must be disabled on all motors."
|
||||
)
|
||||
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
|
||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [
|
||||
name for name in arm.motor_names if name != "gripper"
|
||||
]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Koch motors
|
||||
arm.write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
@@ -463,9 +397,7 @@ class ManipulatorRobot:
|
||||
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write(
|
||||
"Goal_Position", self.config.gripper_open_degree, "gripper"
|
||||
)
|
||||
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
|
||||
|
||||
def set_aloha_robot_preset(self):
|
||||
def set_shadow_(arm):
|
||||
@@ -495,15 +427,11 @@ class ManipulatorRobot:
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [
|
||||
name
|
||||
for name in self.follower_arms[name].motor_names
|
||||
if name != "gripper"
|
||||
name for name in self.follower_arms[name].motor_names if name != "gripper"
|
||||
]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
# 4 corresponds to Extended Position on Aloha motors
|
||||
self.follower_arms[name].write(
|
||||
"Operating_Mode", 4, all_motors_except_gripper
|
||||
)
|
||||
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
|
||||
|
||||
# Use 'position control current based' for follower gripper to be limited by the limit of the current.
|
||||
# It can grasp an object without forcing too much even tho,
|
||||
@@ -551,9 +479,7 @@ class ManipulatorRobot:
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
leader_pos[name] = torch.from_numpy(leader_pos[name])
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_lread_t
|
||||
)
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
|
||||
# Send goal position to the follower
|
||||
follower_goal_pos = {}
|
||||
@@ -561,31 +487,19 @@ class ManipulatorRobot:
|
||||
before_fwrite_t = time.perf_counter()
|
||||
goal_pos = leader_pos[name]
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
if self.config.joint_position_relative_bounds is not None:
|
||||
goal_pos = torch.clamp(
|
||||
goal_pos,
|
||||
self.config.joint_position_relative_bounds["min"],
|
||||
self.config.joint_position_relative_bounds["max"],
|
||||
)
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
|
||||
# Used when record_data=True
|
||||
follower_goal_pos[name] = goal_pos
|
||||
|
||||
goal_pos = goal_pos.numpy().astype(np.int32)
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fwrite_t
|
||||
)
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
@@ -598,9 +512,7 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -622,12 +534,8 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -651,9 +559,7 @@ class ManipulatorRobot:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
follower_pos[name] = torch.from_numpy(follower_pos[name])
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = (
|
||||
time.perf_counter() - before_fread_t
|
||||
)
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -668,12 +574,8 @@ class ManipulatorRobot:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict = {}
|
||||
@@ -706,29 +608,18 @@ class ManipulatorRobot:
|
||||
goal_pos = action[from_idx:to_idx]
|
||||
from_idx = to_idx
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
if self.config.joint_position_relative_bounds is not None:
|
||||
goal_pos = torch.clamp(
|
||||
goal_pos,
|
||||
self.config.joint_position_relative_bounds["min"],
|
||||
self.config.joint_position_relative_bounds["max"],
|
||||
)
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.follower_arms[name].read("Present_Position")
|
||||
present_pos = torch.from_numpy(present_pos)
|
||||
goal_pos = ensure_safe_goal_position(
|
||||
goal_pos, present_pos, self.config.max_relative_target
|
||||
)
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
|
||||
# Save tensor to concat and return
|
||||
action_sent.append(goal_pos)
|
||||
|
||||
# Send goal position to each follower
|
||||
goal_pos = goal_pos.numpy().astype(np.int32)
|
||||
|
||||
self.follower_arms[name].write("Goal_Position", goal_pos)
|
||||
|
||||
return torch.cat(action_sent)
|
||||
|
||||
@@ -60,9 +60,7 @@ class StretchRobot(StretchAPI):
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.startup()
|
||||
if not self.is_connected:
|
||||
print(
|
||||
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
|
||||
)
|
||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||
raise ConnectionError()
|
||||
|
||||
for name in self.cameras:
|
||||
@@ -70,9 +68,7 @@ class StretchRobot(StretchAPI):
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print(
|
||||
"Could not connect to the cameras, check that all cameras are plugged-in."
|
||||
)
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.run_calibration()
|
||||
@@ -117,12 +113,8 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -166,12 +158,8 @@ class StretchRobot(StretchAPI):
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
|
||||
"delta_timestamp_s"
|
||||
]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = (
|
||||
time.perf_counter() - before_camread_t
|
||||
)
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries
|
||||
obs_dict = {}
|
||||
|
||||
@@ -11,7 +11,6 @@ def get_arm_id(name, arm_type):
|
||||
class Robot(Protocol):
|
||||
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
|
||||
robot_type: str
|
||||
features: dict
|
||||
|
||||
def connect(self): ...
|
||||
def run_calibration(self): ...
|
||||
|
||||
@@ -34,8 +34,7 @@ class RobotDeviceNotConnectedError(Exception):
|
||||
"""Exception raised when the robot device is not connected."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
|
||||
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
@@ -17,9 +17,7 @@ import importlib
|
||||
import logging
|
||||
|
||||
|
||||
def is_package_available(
|
||||
pkg_name: str, return_version: bool = False
|
||||
) -> tuple[bool, str] | bool:
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
|
||||
@@ -22,8 +22,6 @@ def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"pkg_resources is deprecated as an API",
|
||||
category=DeprecationWarning,
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
@@ -18,7 +18,6 @@ import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -116,11 +115,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
set_global_random_state(random_state_dict)
|
||||
|
||||
|
||||
def init_logging(log_file=None):
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -134,12 +133,6 @@ def init_logging(log_file=None):
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# File handler
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
@@ -162,16 +155,11 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join(
|
||||
[".."] * (len(path2.parts) - len(common_parts))
|
||||
+ list(path1.parts[len(common_parts) :])
|
||||
)
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(
|
||||
config_path: str, overrides: list[str] | None = None
|
||||
) -> DictConfig:
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
@@ -180,11 +168,7 @@ def init_hydra_config(
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(
|
||||
_relative_path_between(
|
||||
Path(config_path).absolute().parent, Path(__file__).absolute().parent
|
||||
)
|
||||
),
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
@@ -198,26 +182,10 @@ def print_cuda_memory_usage():
|
||||
gc.collect()
|
||||
# Also clear the cache if you want to fully release the memory
|
||||
torch.cuda.empty_cache()
|
||||
print(
|
||||
"Current GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Allocated: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_allocated(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Current GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print(
|
||||
"Maximum GPU Memory Reserved: {:.2f} MB".format(
|
||||
torch.cuda.max_memory_reserved(0) / 1024**2
|
||||
)
|
||||
)
|
||||
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
|
||||
|
||||
def capture_timestamp_utc():
|
||||
@@ -249,33 +217,3 @@ def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
if play_sounds:
|
||||
say(text, blocking)
|
||||
|
||||
|
||||
class TimerManager:
|
||||
def __init__(
|
||||
self,
|
||||
elapsed_time_list: list[float] | None = None,
|
||||
label="Elapsed time",
|
||||
log=True,
|
||||
):
|
||||
self.label = label
|
||||
self.elapsed_time_list = elapsed_time_list
|
||||
self.log = log
|
||||
self.elapsed = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.elapsed: float = time.perf_counter() - self.start
|
||||
|
||||
if self.elapsed_time_list is not None:
|
||||
self.elapsed_time_list.append(self.elapsed)
|
||||
|
||||
if self.log:
|
||||
print(f"{self.label}: {self.elapsed:.6f} seconds")
|
||||
|
||||
@property
|
||||
def elapsed_seconds(self):
|
||||
return self.elapsed
|
||||
|
||||
@@ -2,7 +2,6 @@ defaults:
|
||||
- _self_
|
||||
- env: pusht
|
||||
- policy: diffusion
|
||||
- robot: so100
|
||||
|
||||
hydra:
|
||||
run:
|
||||
|
||||
30
lerobot/configs/env/maniskill_example.yaml
vendored
30
lerobot/configs/env/maniskill_example.yaml
vendored
@@ -1,30 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 400
|
||||
|
||||
env:
|
||||
name: maniskill/pushcube
|
||||
task: PushCube-v1
|
||||
image_size: 64
|
||||
control_mode: pd_ee_delta_pose
|
||||
state_dim: 25
|
||||
action_dim: 7
|
||||
fps: ${fps}
|
||||
obs: rgb
|
||||
render_mode: rgb_array
|
||||
render_size: 64
|
||||
device: cuda
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: null
|
||||
config_path: null
|
||||
|
||||
wrapper:
|
||||
joint_masking_action_space: null
|
||||
delta_action: null
|
||||
|
||||
video_record:
|
||||
enabled: false
|
||||
record_dir: maniskill_videos
|
||||
trajectory_name: trajectory
|
||||
fps: ${fps}
|
||||
46
lerobot/configs/env/so100_real.yaml
vendored
46
lerobot/configs/env/so100_real.yaml
vendored
@@ -1,50 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 10
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 15
|
||||
action_dim: 3
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
device: mps
|
||||
|
||||
wrapper:
|
||||
crop_params_dict:
|
||||
observation.images.front: [171, 207, 116, 251]
|
||||
observation.images.side: [232, 200, 142, 204]
|
||||
resize_size: [128, 128]
|
||||
control_time_s: 10
|
||||
reset_follower_pos: false
|
||||
use_relative_joint_positions: true
|
||||
reset_time_s: 5
|
||||
display_cameras: false
|
||||
delta_action: null #0.3
|
||||
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
|
||||
add_joint_velocity_to_observation: true
|
||||
add_ee_pose_to_observation: true
|
||||
|
||||
# If null then the teleoperation will be used to reset the robot
|
||||
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
|
||||
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
|
||||
# ee_action_space_params: # If null then ee_action_space is not used
|
||||
# bounds:
|
||||
# max: [0.291, 0.147, 0.074]
|
||||
# min: [0.139, -0.143, 0.03]
|
||||
|
||||
# Bounds for insertcube_gamepad dataset and experiments
|
||||
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
|
||||
ee_action_space_params:
|
||||
bounds:
|
||||
max: [0.25295413, 0.07498981, 0.06862044]
|
||||
min: [0.2010096, -0.12, 0.0433196]
|
||||
|
||||
use_gamepad: true
|
||||
x_step_size: 0.03
|
||||
y_step_size: 0.03
|
||||
z_step_size: 0.03
|
||||
|
||||
reward_classifier:
|
||||
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
|
||||
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml
|
||||
|
||||
@@ -114,7 +114,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -95,7 +95,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -95,7 +95,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -95,7 +95,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
hydra:
|
||||
run:
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
|
||||
job:
|
||||
name: default
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
|
||||
# aractingi/push_cube_square_reward_1_cropped_resized
|
||||
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
|
||||
local_files_only: true
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
env:
|
||||
name: "classifier"
|
||||
task: "binary_classification"
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 6
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
grad_clip_norm: 10
|
||||
use_amp: true
|
||||
log_freq: 1
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_keys: ["observation.images.front", "observation.images.side"]
|
||||
label_key: "next.reward"
|
||||
profile_inference_time: false
|
||||
profile_inference_time_iters: 20
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier"
|
||||
model_name: "helper2424/resnet10" # "facebook/convnext-base-224
|
||||
model_type: "cnn"
|
||||
num_cameras: 2 # Has to be len(training.image_keys)
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
project: "classifier-training"
|
||||
job_name: "classifier_training_0"
|
||||
disable_artifact: false
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
|
||||
@@ -1,118 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
|
||||
dataset_repo_id: null
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 512
|
||||
grad_clip_norm: 40.0
|
||||
lr: 3e-4
|
||||
|
||||
|
||||
storage_device: "cuda"
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 10
|
||||
save_freq: 1000000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 200000
|
||||
offline_buffer_capacity: 100000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 500
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 1
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
# vision_encoder_name: "helper2424/resnet10"
|
||||
vision_encoder_name: null
|
||||
# freeze_vision_encoder: true
|
||||
freeze_vision_encoder: false
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 64, 64]
|
||||
output_shapes:
|
||||
action: [7]
|
||||
|
||||
camera_number: 1
|
||||
|
||||
# Normalization / Unnormalization
|
||||
# input_normalization_modes: null
|
||||
input_normalization_modes:
|
||||
observation.state: min_max
|
||||
observation.image: mean_std
|
||||
# input_normalization_params: null
|
||||
input_normalization_params:
|
||||
observation.state:
|
||||
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
|
||||
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
|
||||
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
|
||||
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
|
||||
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
|
||||
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
|
||||
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
|
||||
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
|
||||
0.4001]
|
||||
|
||||
observation.image:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03]
|
||||
max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
|
||||
output_normalization_shapes:
|
||||
action: [7]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
# discount: 0.99
|
||||
discount: 0.80
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
learner_host: "127.0.0.1"
|
||||
learner_port: 50051
|
||||
policy_parameters_push_frequency: 4
|
||||
concurrency:
|
||||
actor: 'threads'
|
||||
learner: 'threads'
|
||||
@@ -1,89 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 128
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 50000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
# image_encoder_hidden_dim: 32
|
||||
discount: 0.99
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_subsample_critics: None
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
critic_target_update_weight: 0.005
|
||||
utd_ratio: 2
|
||||
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -1,120 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: aractingi/insertcube_simple
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
# batch_size: 256
|
||||
batch_size: 512
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 1
|
||||
save_freq: 2000000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 10000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 100 #5000
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
# delta_timestamps:
|
||||
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 1
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
vision_encoder_name: "helper2424/resnet10"
|
||||
freeze_vision_encoder: true
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.images.front: [3, 128, 128]
|
||||
observation.images.side: [3, 128, 128]
|
||||
# observation.image: [3, 128, 128]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.front: mean_std
|
||||
observation.images.side: mean_std
|
||||
observation.state: min_max
|
||||
input_normalization_params:
|
||||
observation.images.front:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.images.side:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
observation.state:
|
||||
# 6- joint positions, 6- joint velocities, 3- ee position
|
||||
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
|
||||
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
|
||||
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-0.03, -0.03, -0.01]
|
||||
max: [0.03, 0.03, 0.03]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
# discount: 0.99
|
||||
discount: 0.97
|
||||
temperature_init: 1.0
|
||||
num_critics: 2 #10
|
||||
camera_number: 2
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 2 # 10
|
||||
|
||||
actor_learner_config:
|
||||
learner_host: "127.0.0.1"
|
||||
learner_port: 50051
|
||||
policy_parameters_push_frequency: 15
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -10,7 +10,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem58760430441
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
@@ -23,7 +23,7 @@ leader_arms:
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem585A0083391
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user