Compare commits
25 Commits
user/youli
...
Cadene-pat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4aef34c8e | ||
|
|
f98200297d | ||
|
|
86bbd16d43 | ||
|
|
0f6e0f6d74 | ||
|
|
fc3e545e03 | ||
|
|
b98ea415c1 | ||
|
|
bbe9057225 | ||
|
|
8c4643687c | ||
|
|
fab037f78d | ||
|
|
03d647269e | ||
|
|
2252b42337 | ||
|
|
bc6384bb80 | ||
|
|
8df7e63d61 | ||
|
|
7a3cb1ad34 | ||
|
|
f8a6574698 | ||
|
|
abbb1d2367 | ||
|
|
0b21210d72 | ||
|
|
461d5472d3 | ||
|
|
c75ea789a8 | ||
|
|
ee200e86cb | ||
|
|
8865e19c12 | ||
|
|
5f5efe7cb9 | ||
|
|
c0101f0948 | ||
|
|
5e54e39795 | ||
|
|
5ffcb48a9a |
52
.github/workflows/build-docker-images.yml
vendored
@@ -14,20 +14,14 @@ env:
|
||||
jobs:
|
||||
latest-cpu:
|
||||
name: CPU
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -55,20 +49,15 @@ jobs:
|
||||
|
||||
latest-cuda:
|
||||
name: GPU
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -95,20 +84,9 @@ jobs:
|
||||
|
||||
latest-cuda-dev:
|
||||
name: GPU Dev
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
6
.github/workflows/nightly-tests.yml
vendored
@@ -16,7 +16,8 @@ jobs:
|
||||
name: CPU
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
container:
|
||||
image: huggingface/lerobot-cpu:latest
|
||||
options: --shm-size "16gb"
|
||||
@@ -43,7 +44,8 @@ jobs:
|
||||
name: GPU
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu"
|
||||
|
||||
28
.github/workflows/quality.yml
vendored
@@ -54,3 +54,31 @@ jobs:
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
|
||||
16
.github/workflows/test-docker-build.yml
vendored
@@ -42,26 +42,14 @@ jobs:
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on: ubuntu-latest
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
2
.github/workflows/trufflehog.yml
vendored
@@ -16,3 +16,5 @@ jobs:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
extra_args: --only-verified
|
||||
|
||||
2
.gitignore
vendored
@@ -121,8 +121,8 @@ celerybeat.pid
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.15.2
|
||||
rev: v3.16.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.4.3
|
||||
rev: v0.5.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -31,3 +31,7 @@ repos:
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.18.4
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
24
Makefile
@@ -26,6 +26,7 @@ test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
@@ -113,7 +114,6 @@ test-diffusion-ete-eval:
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
@@ -133,6 +133,28 @@ test-tdmpc-ete-train:
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
|
||||
31
README.md
@@ -22,8 +22,21 @@
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
<p>State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
@@ -65,17 +78,19 @@
|
||||
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git && cd lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
||||
@@ -89,7 +104,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install ".[aloha, pusht]"
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -114,10 +129,12 @@ wandb login
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
| | ├── policies # various policies: act, diffusion, tdmpc
|
||||
| | ├── robot_devices # various real devices: dynamixel motors, opencv cameras, koch robots
|
||||
| | └── utils # various utilities
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| ├── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
| ├── control_robot.py # teleoperate a real robot, record data, run a policy
|
||||
| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
|
||||
| └── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
@@ -180,8 +197,10 @@ dataset attributes:
|
||||
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
|
||||
│ ...
|
||||
├ info: a dictionary of metadata on the dataset
|
||||
│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
|
||||
│ ├ fps (float): frame per second the dataset is recorded/synchronized to
|
||||
│ └ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
|
||||
│ ├ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
|
||||
│ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
|
||||
├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
|
||||
└ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
|
||||
```
|
||||
|
||||
@@ -257,10 +257,10 @@ def benchmark_encoding_decoding(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
video_codec=encoding_cfg["vcodec"],
|
||||
pixel_format=encoding_cfg["pix_fmt"],
|
||||
group_of_pictures_size=encoding_cfg.get("g"),
|
||||
constant_rate_factor=encoding_cfg.get("crf"),
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
@@ -21,7 +22,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
|
||||
@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -43,7 +44,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \
|
||||
libdav1d-dev
|
||||
|
||||
|
||||
# Install gh cli tool
|
||||
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||
|
||||
@@ -9,7 +9,8 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -23,7 +24,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
|
||||
@@ -18,8 +18,6 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
@@ -27,6 +25,17 @@ pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
|
||||
1005
examples/7_get_started_with_real_robot.md
Normal file
@@ -80,7 +80,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -125,6 +125,10 @@ available_real_world_datasets = [
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
"lerobot/umi_cup_in_the_wild",
|
||||
"lerobot/unitreeh1_fold_clothes",
|
||||
"lerobot/unitreeh1_rearrange_objects",
|
||||
"lerobot/unitreeh1_two_robot_greeting",
|
||||
"lerobot/unitreeh1_warehouse",
|
||||
]
|
||||
|
||||
available_datasets = list(
|
||||
|
||||
@@ -35,15 +35,15 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v1.6"
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
CODEBASE_VERSION = "v1.5"
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
@@ -52,7 +52,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
@@ -60,16 +59,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, version, root)
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
|
||||
@property
|
||||
@@ -164,7 +163,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
@@ -173,6 +171,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -180,7 +179,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def from_preloaded(
|
||||
cls,
|
||||
repo_id: str = "from_preloaded",
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
@@ -204,7 +202,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# create an empty object of type LeRobotDataset
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.version = version
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.image_transforms = transform
|
||||
@@ -228,7 +225,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
@@ -242,7 +238,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
version=version,
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
@@ -279,7 +274,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
self.disabled_data_keys.update(extra_keys)
|
||||
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
@@ -395,7 +389,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
|
||||
384
lerobot/common/datasets/online_buffer.py
Normal file
@@ -0,0 +1,384 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An online buffer for the online training loop in train.py
|
||||
|
||||
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
|
||||
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
|
||||
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
|
||||
supports in-place slicing and mutation which is very handy for a dynamic buffer.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def _make_memmap_safe(**kwargs) -> np.memmap:
|
||||
"""Make a numpy memmap with checks on available disk space first.
|
||||
|
||||
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
|
||||
|
||||
For information on dtypes:
|
||||
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
|
||||
"""
|
||||
if kwargs["mode"].startswith("w"):
|
||||
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
|
||||
stats = os.statvfs(Path(kwargs["filename"]).parent)
|
||||
available_space = stats.f_bavail * stats.f_frsize # bytes
|
||||
if required_space >= available_space * 0.8:
|
||||
raise RuntimeError(
|
||||
f"You're about to take up {required_space} of {available_space} bytes available."
|
||||
)
|
||||
return np.memmap(**kwargs)
|
||||
|
||||
|
||||
class OnlineBuffer(torch.utils.data.Dataset):
|
||||
"""FIFO data buffer for the online training loop in train.py.
|
||||
|
||||
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
|
||||
loop in the same way that a LeRobotDataset would be used.
|
||||
|
||||
The underlying data structure will have data inserted in a circular fashion. Always insert after the
|
||||
last index, and when you reach the end, wrap around to the start.
|
||||
|
||||
The data is stored in a numpy memmap.
|
||||
"""
|
||||
|
||||
NEXT_INDEX_KEY = "_next_index"
|
||||
OCCUPANCY_MASK_KEY = "_occupancy_mask"
|
||||
INDEX_KEY = "index"
|
||||
FRAME_INDEX_KEY = "frame_index"
|
||||
EPISODE_INDEX_KEY = "episode_index"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
IS_PAD_POSTFIX = "_is_pad"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
write_dir: str | Path,
|
||||
data_spec: dict[str, Any] | None,
|
||||
buffer_capacity: int | None,
|
||||
fps: float | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
|
||||
):
|
||||
"""
|
||||
The online buffer can be provided from scratch or you can load an existing online buffer by passing
|
||||
a `write_dir` associated with an existing buffer.
|
||||
|
||||
Args:
|
||||
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
|
||||
Note that if the files already exist, they are opened in read-write mode (used for training
|
||||
resumption.)
|
||||
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
|
||||
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
|
||||
but note that "index", "frame_index" and "episode_index" are already accounted for by this
|
||||
class, so you don't need to include them.
|
||||
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
|
||||
system's available disk space when choosing this.
|
||||
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
|
||||
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
|
||||
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
|
||||
converted to dict[str, np.ndarray] for optimization purposes.
|
||||
|
||||
"""
|
||||
self.set_delta_timestamps(delta_timestamps)
|
||||
self._fps = fps
|
||||
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
|
||||
# the requested frames. It is only used when `delta_timestamps` is provided.
|
||||
# minus 1e-4 to account for possible numerical error
|
||||
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
|
||||
self._buffer_capacity = buffer_capacity
|
||||
data_spec = self._make_data_spec(data_spec, buffer_capacity)
|
||||
Path(write_dir).mkdir(parents=True, exist_ok=True)
|
||||
self._data = {}
|
||||
for k, v in data_spec.items():
|
||||
self._data[k] = _make_memmap_safe(
|
||||
filename=Path(write_dir) / k,
|
||||
dtype=v["dtype"] if v is not None else None,
|
||||
mode="r+" if (Path(write_dir) / k).exists() else "w+",
|
||||
shape=tuple(v["shape"]) if v is not None else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
|
||||
return self._delta_timestamps
|
||||
|
||||
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
|
||||
"""Set delta_timestamps converting the values to numpy arrays.
|
||||
|
||||
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
|
||||
need to be converted into numpy arrays.
|
||||
"""
|
||||
if value is not None:
|
||||
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
|
||||
)
|
||||
preset_keys = {
|
||||
OnlineBuffer.INDEX_KEY,
|
||||
OnlineBuffer.FRAME_INDEX_KEY,
|
||||
OnlineBuffer.EPISODE_INDEX_KEY,
|
||||
OnlineBuffer.TIMESTAMP_KEY,
|
||||
}
|
||||
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
|
||||
raise ValueError(
|
||||
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
|
||||
f"The provided data_spec has {intersection}."
|
||||
)
|
||||
complete_data_spec = {
|
||||
# _next_index will be a pointer to the next index that we should start filling from when we add
|
||||
# more data.
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
"""Add new data to the buffer, which could potentially mean shifting old data out.
|
||||
|
||||
The new data should contain all the frames (in order) of any number of episodes. The indices should
|
||||
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
|
||||
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
|
||||
|
||||
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
|
||||
will be done in place!
|
||||
"""
|
||||
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
|
||||
raise ValueError(f"Missing data keys: {missing_keys}")
|
||||
new_data_length = len(data[self.data_keys[0]])
|
||||
if not all(len(data[k]) == new_data_length for k in self.data_keys):
|
||||
raise ValueError("All data items should have the same length")
|
||||
|
||||
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
|
||||
|
||||
# Sanity check to make sure that the new data indices start from 0.
|
||||
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_samples > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
|
||||
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
|
||||
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
|
||||
for k in self.data_keys:
|
||||
if n_surplus == 0:
|
||||
slc = slice(next_index, next_index + new_data_length)
|
||||
self._data[k][slc] = data[k]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
|
||||
else:
|
||||
self._data[k][next_index:] = data[k][:-n_surplus]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
|
||||
self._data[k][:n_surplus] = data[k][-n_surplus:]
|
||||
if n_surplus == 0:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
|
||||
else:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
|
||||
|
||||
@property
def data_keys(self) -> list[str]:
    """Sorted names of the entries of `self._data`, without the occupancy-mask and next-index entries."""
    remaining = set(self._data)
    for reserved in (OnlineBuffer.OCCUPANCY_MASK_KEY, OnlineBuffer.NEXT_INDEX_KEY):
        # `remove` rather than `discard`: a missing reserved entry raises KeyError.
        remaining.remove(reserved)
    return sorted(remaining)
|
||||
|
||||
@property
|
||||
def fps(self) -> float | None:
|
||||
return self._fps
|
||||
|
||||
@property
def num_episodes(self) -> int:
    """Number of distinct episode indices found in the occupied slots of the buffer."""
    occupied = self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
    stored_episode_indices = self._data[OnlineBuffer.EPISODE_INDEX_KEY][occupied]
    return len(np.unique(stored_episode_indices))
|
||||
|
||||
@property
def num_samples(self) -> int:
    """Number of slots flagged as occupied in the occupancy mask."""
    occupancy_mask = self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
    return np.count_nonzero(occupancy_mask)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
for k, v in item.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
item_[k] = v
|
||||
elif isinstance(v, np.ndarray):
|
||||
item_[k] = torch.from_numpy(v)
|
||||
else:
|
||||
item_[k] = torch.tensor(v)
|
||||
return item_
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
    """Return the frame stored at position `idx` as a dictionary of tensors.

    When `self.delta_timestamps` is None the item only contains the frame at `idx`. Otherwise, for every
    key of `self.delta_timestamps`, the entry is replaced by the frames of the same episode whose
    timestamps are the closest to `current timestamp + deltas`, and an extra
    `f"{key}{OnlineBuffer.IS_PAD_POSTFIX}"` entry flags the queries whose closest frame is further away
    than `self.tolerance_s`.

    Raises:
        IndexError: If `idx` is outside of `[-len(self), len(self))`.
    """
    n_frames = len(self)
    if not (-n_frames <= idx < n_frames):
        raise IndexError

    # Entries of `self._data` whose key starts with "_" are left out of the returned item.
    item = {}
    for key, values in self._data.items():
        if not key.startswith("_"):
            item[key] = values[idx]

    if self.delta_timestamps is None:
        return self._item_to_tensors(item)

    current_ts = item[OnlineBuffer.TIMESTAMP_KEY]

    # Occupied slots that hold a frame of the same episode as the requested one.
    same_episode = self._data[OnlineBuffer.EPISODE_INDEX_KEY] == item[OnlineBuffer.EPISODE_INDEX_KEY]
    occupied = self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
    episode_data_indices = np.where(np.bitwise_and(same_episode, occupied))[0]
    episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]

    for data_key in self.delta_timestamps:
        # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
        # Timestamps used as queries to retrieve the previous/future frames.
        query_ts = current_ts + self.delta_timestamps[data_key]

        # One row per query timestamp, one column per frame of the episode.
        dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
        argmin_ = np.argmin(dist, axis=1)
        min_ = dist[np.arange(dist.shape[0]), argmin_]

        is_pad = min_ > self.tolerance_s

        # Queries exceeding the tolerance are only accepted before the first or after the last
        # timestamp of the episode.
        padded_queries = query_ts[is_pad]
        before_episode = padded_queries < episode_timestamps[0]
        after_episode = episode_timestamps[-1] < padded_queries
        assert (before_episode | after_episode).all(), (
            f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
            ") inside the episode range."
        )

        # Frames for this data key, plus the padding flags under the postfixed key.
        item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
        item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad

    return self._item_to_tensors(item)
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
    """Return, as a tensor, the values stored under `key` restricted to the occupied slots."""
    values = self._data[key]
    occupied = self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
    return torch.from_numpy(values[occupied])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
offline_dataset: LeRobotDataset,
|
||||
offline_drop_n_last_frames: int = 0,
|
||||
online_dataset: OnlineBuffer | None = None,
|
||||
online_sampling_ratio: float | None = None,
|
||||
online_drop_n_last_frames: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the sampling weights for the online training dataloader in train.py.
|
||||
|
||||
Args:
|
||||
offline_dataset: The LeRobotDataset used for offline pre-training.
|
||||
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
|
||||
online_dataset: The OnlineBuffer used in online training.
|
||||
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
|
||||
online dataset is provided, this value must also be provided.
|
||||
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
|
||||
dataset.
|
||||
Returns:
|
||||
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
|
||||
|
||||
Notes to maintainers:
|
||||
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
|
||||
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
|
||||
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
|
||||
is the ability to turn shuffling off.
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
if len(offline_dataset) > 0:
|
||||
offline_data_mask_indices = []
|
||||
for start_index, end_index in zip(
|
||||
offline_dataset.episode_data_index["from"],
|
||||
offline_dataset.episode_data_index["to"],
|
||||
strict=True,
|
||||
):
|
||||
offline_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
||||
)
|
||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(offline_dataset),),
|
||||
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
|
||||
)
|
||||
* offline_data_mask
|
||||
)
|
||||
|
||||
if online_dataset is not None and len(online_dataset) > 0:
|
||||
online_data_mask_indices = []
|
||||
episode_indices = online_dataset.get_data_by_key("episode_index")
|
||||
for episode_idx in torch.unique(episode_indices):
|
||||
where_episode = torch.where(episode_indices == episode_idx)
|
||||
start_index = where_episode[0][0]
|
||||
end_index = where_episode[0][-1] + 1
|
||||
online_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
|
||||
)
|
||||
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
|
||||
online_data_mask[torch.tensor(online_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(online_dataset),),
|
||||
fill_value=online_sampling_ratio / online_data_mask.sum(),
|
||||
)
|
||||
* online_data_mask
|
||||
)
|
||||
|
||||
weights = torch.cat(weights)
|
||||
|
||||
if weights.sum() == 0:
|
||||
weights += 1 / len(weights)
|
||||
else:
|
||||
weights /= weights.sum()
|
||||
|
||||
return weights
|
||||
@@ -0,0 +1,56 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
The datasets pushed to the hub are decoupled from the evolution of this repo, so to ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also has this version number in its
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor about being
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` in `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require running `push_dataset_to_hub.py`. You can just "branch out" from the `main` branch on the HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -19,8 +19,8 @@ This file contains download scripts for raw datasets.
|
||||
Example of usage:
|
||||
```
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
|
||||
--raw-dir data/cadene/pusht_raw \
|
||||
--repo-id cadene/pusht_raw
|
||||
--raw-dir data/lerobot-raw/pusht_raw \
|
||||
--repo-id lerobot-raw/pusht_raw
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -31,86 +31,87 @@ from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
|
||||
# {raw_repo_id: raw_format}
|
||||
AVAILABLE_RAW_REPO_IDS = {
|
||||
"lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
|
||||
"lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
|
||||
"lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
|
||||
}
|
||||
|
||||
|
||||
def download_raw(raw_dir: Path, repo_id: str):
|
||||
# Check repo_id is well formated
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but contains '{repo_id}'."
|
||||
)
|
||||
check_repo_id(repo_id)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
if not dataset_id.endswith("_raw"):
|
||||
warnings.warn(
|
||||
f"`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this naming convention by renaming your repository is advised, but not mandatory.",
|
||||
f"""`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this
|
||||
naming convention by renaming your repository is advised, but not mandatory.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
# Send warning if raw_dir isn't well formated
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised, but not mandatory.",
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,
|
||||
but not mandatory.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
snapshot_download(f"{repo_id}", repo_type="dataset", local_dir=raw_dir)
|
||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
|
||||
|
||||
def download_all_raw_datasets():
|
||||
data_dir = Path("data")
|
||||
repo_ids = [
|
||||
"cadene/pusht_image_raw",
|
||||
"cadene/xarm_lift_medium_image_raw",
|
||||
"cadene/xarm_lift_medium_replay_image_raw",
|
||||
"cadene/xarm_push_medium_image_raw",
|
||||
"cadene/xarm_push_medium_replay_image_raw",
|
||||
"cadene/aloha_sim_insertion_human_image_raw",
|
||||
"cadene/aloha_sim_insertion_scripted_image_raw",
|
||||
"cadene/aloha_sim_transfer_cube_human_image_raw",
|
||||
"cadene/aloha_sim_transfer_cube_scripted_image_raw",
|
||||
"cadene/pusht_raw",
|
||||
"cadene/xarm_lift_medium_raw",
|
||||
"cadene/xarm_lift_medium_replay_raw",
|
||||
"cadene/xarm_push_medium_raw",
|
||||
"cadene/xarm_push_medium_replay_raw",
|
||||
"cadene/aloha_sim_insertion_human_raw",
|
||||
"cadene/aloha_sim_insertion_scripted_raw",
|
||||
"cadene/aloha_sim_transfer_cube_human_raw",
|
||||
"cadene/aloha_sim_transfer_cube_scripted_raw",
|
||||
"cadene/aloha_mobile_cabinet_raw",
|
||||
"cadene/aloha_mobile_chair_raw",
|
||||
"cadene/aloha_mobile_elevator_raw",
|
||||
"cadene/aloha_mobile_shrimp_raw",
|
||||
"cadene/aloha_mobile_wash_pan_raw",
|
||||
"cadene/aloha_mobile_wipe_wine_raw",
|
||||
"cadene/aloha_static_battery_raw",
|
||||
"cadene/aloha_static_candy_raw",
|
||||
"cadene/aloha_static_coffee_raw",
|
||||
"cadene/aloha_static_coffee_new_raw",
|
||||
"cadene/aloha_static_cups_open_raw",
|
||||
"cadene/aloha_static_fork_pick_up_raw",
|
||||
"cadene/aloha_static_pingpong_test_raw",
|
||||
"cadene/aloha_static_pro_pencil_raw",
|
||||
"cadene/aloha_static_screw_driver_raw",
|
||||
"cadene/aloha_static_tape_raw",
|
||||
"cadene/aloha_static_thread_velcro_raw",
|
||||
"cadene/aloha_static_towel_raw",
|
||||
"cadene/aloha_static_vinh_cup_raw",
|
||||
"cadene/aloha_static_vinh_cup_left_raw",
|
||||
"cadene/aloha_static_ziploc_slide_raw",
|
||||
"cadene/umi_cup_in_the_wild_raw",
|
||||
]
|
||||
for repo_id in repo_ids:
|
||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||
if data_dir is None:
|
||||
data_dir = Path("data")
|
||||
for repo_id in AVAILABLE_RAW_REPO_IDS:
|
||||
raw_dir = data_dir / repo_id
|
||||
download_raw(raw_dir, repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(
|
||||
description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
@@ -122,7 +123,8 @@ def main():
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).",
|
||||
help="""Repositery identifier on Hugging Face: a community or a user name `/` the name of
|
||||
the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).""",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
download_raw(**vars(args))
|
||||
|
||||
184
lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Use this script to batch encode lerobot dataset from their raw format to LeRobotDataset and push their updated
|
||||
version to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
|
||||
downloaded raw datasets, which you can do with the related '_download_raw.py' script.
|
||||
|
||||
For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
|
||||
lerobot-raw were downloaded in 'raw/datasets/directory':
|
||||
```bash
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
|
||||
--raw-dir raw/datasets/directory \
|
||||
--raw-repo-ids lerobot-raw \
|
||||
--local-dir push/datasets/directory \
|
||||
--tests-data-dir tests/data \
|
||||
--push-repo lerobot \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 2 \
|
||||
--crf 30
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
|
||||
def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
    """Build the repo id to push to from a raw repo id.

    The second "/"-separated component of `raw_repo_id` is taken as the dataset name, a trailing "_raw" is
    stripped from it, and the result is prefixed with `push_repo`
    (e.g. "lerobot-raw/pusht_raw" and "lerobot" give "lerobot/pusht").
    """
    dataset_name = raw_repo_id.split("/")[1].removesuffix("_raw")
    return f"{push_repo}/{dataset_name}"
|
||||
|
||||
|
||||
def encode_datasets(
|
||||
raw_dir: Path,
|
||||
raw_repo_ids: list[str],
|
||||
push_repo: str,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
local_dir: Path | None = None,
|
||||
tests_data_dir: Path | None = None,
|
||||
raw_format: str | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
|
||||
raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
|
||||
else:
|
||||
if raw_format is None:
|
||||
raise ValueError(raw_format)
|
||||
raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}
|
||||
|
||||
for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
}
|
||||
|
||||
if not (dataset_raw_dir).is_dir():
|
||||
raise NotADirectoryError(dataset_raw_dir)
|
||||
|
||||
if not dry_run:
|
||||
push_dataset_to_hub(
|
||||
dataset_raw_dir,
|
||||
raw_format=repo_raw_format,
|
||||
repo_id=dataset_repo_id_push,
|
||||
local_dir=dataset_dir,
|
||||
resume=True,
|
||||
encoding=encoding,
|
||||
tests_data_dir=tests_data_dir,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Parse the command line options and forward them to `encode_datasets` as keyword arguments."""
    # (flag, `add_argument` keyword arguments), registered in this order. The line breaks inside the
    # multi-line help strings are reproduced exactly as they appear in the existing text.
    cli_options = [
        (
            "--raw-dir",
            {"type": Path, "default": Path("data"), "help": "Directory where raw datasets are located."},
        ),
        (
            "--raw-repo-ids",
            {
                "type": str,
                "nargs": "*",
                "default": ["lerobot-raw"],
                "help": """Raw dataset repo ids. if 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be
used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their
associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` """,
            },
        ),
        (
            "--raw-format",
            {
                "type": str,
                "default": None,
                "help": """Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not
'lerobot-raw'""",
            },
        ),
        (
            "--local-dir",
            {
                "type": Path,
                "default": None,
                "help": """When provided, writes the dataset converted to LeRobotDataset format in this directory
(e.g. `data/lerobot/aloha_mobile_chair`).""",
            },
        ),
        (
            "--push-repo",
            {"type": str, "default": "lerobot", "help": "Repo to upload datasets to"},
        ),
        (
            "--vcodec",
            {"type": str, "default": "libsvtav1", "help": "Codec to use for encoding videos"},
        ),
        (
            "--pix-fmt",
            {
                "type": str,
                "default": "yuv420p",
                "help": "Pixel formats (chroma subsampling) to be used for encoding",
            },
        ),
        (
            "--g",
            {"type": int, "default": 2, "help": "Group of pictures sizes to be used for encoding."},
        ),
        (
            "--crf",
            {"type": int, "default": 30, "help": "Constant rate factors to be used for encoding."},
        ),
        (
            "--tests-data-dir",
            {
                "type": Path,
                "default": None,
                "help": (
                    "When provided, save tests artifacts into the given directory "
                    "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
                ),
            },
        ),
        (
            "--dry-run",
            {
                "type": int,
                "default": 0,
                "help": "If not set to 0, this script won't download or upload anything.",
            },
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()
    # The parsed attribute names (dashes turned into underscores) are used as keyword arguments.
    encode_datasets(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -28,7 +28,12 @@ import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
@@ -71,7 +76,14 @@ def check_format(raw_dir) -> bool:
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# only frames from simulation are uncompressed
|
||||
compressed_images = "sim" not in raw_dir.name
|
||||
|
||||
@@ -122,7 +134,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
@@ -199,6 +211,7 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
@@ -206,11 +219,15 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 50
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
@@ -80,8 +81,9 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
if video or episodes is not None:
|
||||
if video or episodes or encoding is not None:
|
||||
# TODO(aliberts): support this
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -95,6 +97,7 @@ def from_raw_to_lerobot_format(
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
|
||||
@@ -18,12 +18,14 @@ Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
@@ -198,6 +200,7 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
@@ -210,11 +213,21 @@ def from_raw_to_lerobot_format(
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
if encoding is not None:
|
||||
warnings.warn(
|
||||
"Video encoding is currently done outside of LeRobot for the dora_parquet format.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = "unknown"
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -25,7 +25,12 @@ import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
@@ -61,6 +66,7 @@ def load_from_raw(
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
keypoints_instead_of_image: bool = False,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
try:
|
||||
import pymunk
|
||||
@@ -171,7 +177,7 @@ def load_from_raw(
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
@@ -243,6 +249,7 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
|
||||
# with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
|
||||
@@ -254,11 +261,15 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video if not keypoints_instead_of_image else 0,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -25,8 +25,13 @@ import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
@@ -59,7 +64,14 @@ def check_format(raw_dir) -> bool:
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
zarr_data = zarr.open(zarr_path, mode="r")
|
||||
|
||||
@@ -87,49 +99,61 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
ep_dicts_dir = videos_dir / "ep_dicts"
|
||||
ep_dicts_dir.mkdir(exist_ok=True, parents=True)
|
||||
ep_dicts = []
|
||||
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
ep_dict_path = ep_dicts_dir / f"{ep_idx}"
|
||||
if not ep_dict_path.is_file():
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
# TODO(rcadene): save temporary images of the episode?
|
||||
# TODO(rcadene): save temporary images of the episode?
|
||||
|
||||
state = states[from_idx:to_idx]
|
||||
state = states[from_idx:to_idx]
|
||||
|
||||
ep_dict = {}
|
||||
ep_dict = {}
|
||||
|
||||
# load 57MB of images in RAM (400x224x224x3 uint8)
|
||||
imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
# load 57MB of images in RAM (400x224x224x3 uint8)
|
||||
imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
if not video_path.is_file():
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
# encode images to a mp4 video
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
torch.save(ep_dict, ep_dict_path)
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
ep_dict = torch.load(ep_dict_path)
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
@@ -182,6 +206,7 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
@@ -195,11 +220,15 @@ def from_raw_to_lerobot_format(
|
||||
"Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
|
||||
)
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
@@ -20,6 +21,8 @@ import numpy
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
|
||||
|
||||
def concatenate_episodes(ep_dicts):
|
||||
data_dict = {}
|
||||
@@ -51,3 +54,21 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
    """Collect the default encoding keyword arguments declared by `encode_video_frames`.

    The signature of `encode_video_frames` is inspected and the default values of its
    "vcodec", "pix_fmt", "g" and "crf" parameters are gathered, in signature order.
    Parameters without a default value are left out.

    Returns:
        A dict mapping each of those parameter names to its default value.
    """
    encoding_keys = ["vcodec", "pix_fmt", "g", "crf"]
    defaults = {}
    for name, param in inspect.signature(encode_video_frames).parameters.items():
        # Parameters without a default carry the `inspect.Parameter.empty` sentinel.
        if param.default is inspect.Parameter.empty:
            continue
        if name in encoding_keys:
            defaults[name] = param.default
    return defaults
|
||||
|
||||
|
||||
def check_repo_id(repo_id: str) -> None:
    """Validate that `repo_id` has the form `<community_or_user_id>/<dataset_name>`.

    Args:
        repo_id: Repository identifier to validate, e.g. 'lerobot/pusht'.

    Raises:
        ValueError: If `repo_id` is not made of exactly two non-empty parts separated by a
            single `/`.
    """
    parts = repo_id.split("/")
    # `all(parts)` rejects ids with an empty side such as "user/" or "/dataset".
    if len(parts) != 2 or not all(parts):
        # Implicit string concatenation keeps the message on one line; a triple-quoted literal
        # would leak the source newline and indentation into the error text.
        raise ValueError(
            "`repo_id` is expected to contain a community or user id `/` the name of the dataset "
            f"(e.g. 'lerobot/pusht'), but contains '{repo_id}'."
        )
|
||||
|
||||
@@ -25,7 +25,12 @@ import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
@@ -55,7 +60,14 @@ def check_format(raw_dir):
|
||||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
pkl_path = raw_dir / "buffer.pkl"
|
||||
|
||||
with open(pkl_path, "rb") as f:
|
||||
@@ -104,7 +116,7 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
@@ -166,6 +178,7 @@ def from_raw_to_lerobot_format(
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
@@ -173,11 +186,15 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 15
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -15,17 +15,27 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
@@ -80,7 +90,28 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
@cache
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
    """Return a revision name that exists as a branch of the `repo_id` dataset repository.

    The branches of the dataset repository are listed through `HfApi.list_repo_refs`. When
    `version` is one of them it is returned unchanged. Otherwise a warning is emitted and
    "main" is returned as a fallback. Successful results are memoized by `functools.cache`,
    so a given `(repo_id, version)` pair is only looked up once per process.

    Args:
        repo_id: Identifier of the dataset repository.
        version: Requested revision (branch name).

    Returns:
        `version` if it is an existing branch, else "main".

    Raises:
        ValueError: If `version` is not a branch and the repository has no "main" branch.
    """
    api = HfApi()
    dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
    branches = [b.name for b in dataset_info.branches]
    if version in branches:
        return version

    # The message is built from adjacent literals so that no source newline or indentation
    # ends up in the emitted warning.
    warnings.warn(
        f"You are trying to load a dataset from {repo_id} created with a previous version of the "
        f"codebase. The following versions are available: {branches}. "
        f"The requested version ('{version}') is not found. You should be fine since "
        "backward compatibility is maintained. If you encounter a problem, contact LeRobot "
        "maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.",
        stacklevel=1,
    )
    if "main" not in branches:
        raise ValueError(f"Version 'main' not found on {repo_id}")
    return "main"
|
||||
|
||||
|
||||
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
@@ -101,7 +132,9 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -119,8 +152,9 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
@@ -137,7 +171,10 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
@@ -154,7 +191,8 @@ def load_info(repo_id, version, root) -> dict:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
|
||||
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
@@ -166,7 +204,8 @@ def load_videos(repo_id, version, root) -> Path:
|
||||
path = Path(root) / repo_id / "videos"
|
||||
else:
|
||||
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
|
||||
path = Path(repo_dir) / "videos"
|
||||
|
||||
return path
|
||||
@@ -354,3 +393,29 @@ def cycle(iterable):
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
    """Create `branch` on an existing Hugging Face repo.

    If a branch with the same name is already present on the repo, it is deleted first and
    then created again.

    Args:
        repo_id: Identifier of the repository.
        branch: Name of the branch to (re)create.
        repo_type: Repository type forwarded to the `HfApi` calls.
    """
    api = HfApi()

    target_ref = f"refs/heads/{branch}"
    existing_refs = [b.ref for b in api.list_repo_refs(repo_id, repo_type=repo_type).branches]
    if target_ref in existing_refs:
        # Drop the stale branch so the creation below starts from a clean state.
        api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

    api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
    """Build a dataset card from `DATASET_CARD_TEMPLATE` for a LeRobot dataset.

    The card metadata always gets the "robotics" task category and the "LeRobot" tag.

    Args:
        tags: Optional extra tags appended after "LeRobot".
        text: Optional text appended to the card's text.

    Returns:
        The populated `DatasetCard`.
    """
    card = DatasetCard(DATASET_CARD_TEMPLATE)
    card.data.task_categories = ["robotics"]

    all_tags = ["LeRobot"]
    if tags is not None:
        all_tags += tags
    card.data.tags = all_tags

    if text is not None:
        card.text += text
    return card
|
||||
|
||||
@@ -166,10 +166,10 @@ def encode_video_frames(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
video_codec: str = "libsvtav1",
|
||||
pixel_format: str = "yuv420p",
|
||||
group_of_pictures_size: int | None = 2,
|
||||
constant_rate_factor: int | None = 30,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: str | None = "error",
|
||||
overwrite: bool = False,
|
||||
@@ -183,20 +183,20 @@ def encode_video_frames(
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame_%06d.png")),
|
||||
("-vcodec", video_codec),
|
||||
("-pix_fmt", pixel_format),
|
||||
("-vcodec", vcodec),
|
||||
("-pix_fmt", pix_fmt),
|
||||
]
|
||||
)
|
||||
|
||||
if group_of_pictures_size is not None:
|
||||
ffmpeg_args["-g"] = str(group_of_pictures_size)
|
||||
if g is not None:
|
||||
ffmpeg_args["-g"] = str(g)
|
||||
|
||||
if constant_rate_factor is not None:
|
||||
ffmpeg_args["-crf"] = str(constant_rate_factor)
|
||||
if crf is not None:
|
||||
ffmpeg_args["-crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "-svtav1-params" if video_codec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if video_codec == "libsvtav1" else "fastdecode"
|
||||
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
ffmpeg_args[key] = value
|
||||
|
||||
if log_level is not None:
|
||||
@@ -207,7 +207,14 @@ def encode_video_frames(
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
subprocess.run(ffmpeg_cmd, check=True)
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not video_path.exists():
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {video_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -19,7 +19,7 @@ import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
@@ -27,6 +27,9 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
|
||||
@@ -76,12 +76,10 @@ class ACTConfig:
|
||||
documentation in the policy class).
|
||||
latent_dim: The VAE's latent dimension.
|
||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||||
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
|
||||
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
|
||||
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
|
||||
more information on how ensembling works, please see `ACTTemporalEnsembler`.
|
||||
dropout: Dropout to use in the transformer layers (see code for details).
|
||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
@@ -139,7 +137,8 @@ class ACTConfig:
|
||||
n_vae_encoder_layers: int = 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: float | None = None
|
||||
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
|
||||
temporal_ensemble_coeff: float | None = None
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: float = 0.1
|
||||
@@ -151,7 +150,7 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||||
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||
"because the policy needs to be queried every step to compute the ensembled action."
|
||||
|
||||
@@ -38,7 +38,13 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
@@ -77,12 +83,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
self._ensembled_actions = None
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler.reset()
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@@ -98,26 +107,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
if self._ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
self._ensembled_actions = actions.clone()
|
||||
else:
|
||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the EMA update for those entries.
|
||||
alpha = self.config.temporal_ensemble_momentum
|
||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
||||
# "Consume" the first action.
|
||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
||||
action = self.temporal_ensembler.update(actions)
|
||||
return action
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
@@ -137,6 +135,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
@@ -162,6 +161,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
return loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
    """Online, exponentially weighted temporal ensembling of predicted action chunks."""

    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

        Weights are wᵢ = exp(-temporal_ensemble_coeff * i), with w₀ belonging to the oldest action,
        and are normalized by dividing by Σwᵢ. Intuition for the coefficient:
            - 0 weighs all actions uniformly.
            - A positive value gives more weight to older actions.
            - A negative value gives more weight to newer actions.
        NOTE: The original ACT work uses 0.01 by default, so older actions weigh more than newer ones
        (the experiments in https://github.com/huggingface/lerobot/pull/319 hint at why weighing new
        actions highly might be detrimental: doing so aggressively may diminish the benefits of action
        chunking).

        The average is maintained online instead of caching a history of actions and averaging
        offline. For a simple 1D sequence the two approaches compare like this:

        ```
        import torch

        seq = torch.linspace(8, 8.5, 100)
        print(seq)

        m = 0.01
        exp_weights = torch.exp(-m * torch.arange(len(seq)))
        print(exp_weights)

        # Calculate offline
        avg = (exp_weights * seq).sum() / exp_weights.sum()
        print("offline", avg)

        # Calculate online
        for i, item in enumerate(seq):
            if i == 0:
                avg = item
                continue
            avg *= exp_weights[:i].sum()
            avg += item * exp_weights[i]
            avg /= exp_weights[:i+1].sum()
        print("online", avg)
        ```
        """
        self.chunk_size = chunk_size
        steps = torch.arange(chunk_size)
        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * steps)
        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
        self.reset()

    def reset(self):
        """Clear the running ensemble and its per-step counters."""
        self.ensembled_actions = None
        # Per remaining time step: how many predictions have been folded into the ensemble so far.
        self.ensembled_actions_count = None

    def update(self, actions: Tensor) -> Tensor:
        """Fold a new (batch, chunk_size, action_dim) prediction into the ensemble.

        Every time step still held in the ensemble is updated with the matching entry of
        `actions`; the next action in the sequence is then popped and returned.
        """
        device = actions.device
        self.ensemble_weights = self.ensemble_weights.to(device=device)
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=device)

        if self.ensembled_actions is not None:
            # `running` has shape (batch_size, chunk_size - 1, action_dim): the steps left over from
            # previous calls. The in-place ops below apply the online update to exactly those steps.
            running = self.ensembled_actions
            counts = self.ensembled_actions_count
            running *= self.ensemble_weights_cumsum[counts - 1]
            running += actions[:, :-1] * self.ensemble_weights[counts]
            running /= self.ensemble_weights_cumsum[counts]
            counts = torch.clamp(counts + 1, max=self.chunk_size)
            # The newest final step has no running average yet, so it is appended as-is with count 1.
            self.ensembled_actions = torch.cat([running, actions[:, -1:]], dim=1)
            self.ensembled_actions_count = torch.cat([counts, torch.ones_like(counts[-1:])])
        else:
            # First call since the last reset: the ensemble starts as a copy of this prediction.
            self.ensembled_actions = actions.clone()
            # The trailing singleton dimension lets the counts broadcast in the tensor ops above.
            self.ensembled_actions_count = torch.ones(
                (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
            )

        # "Consume" the first action and keep the remainder for the next call.
        next_action = self.ensembled_actions[:, 0]
        self.ensembled_actions = self.ensembled_actions[:, 1:]
        self.ensembled_actions_count = self.ensembled_actions_count[1:]
        return next_action
|
||||
|
||||
|
||||
class ACT(nn.Module):
|
||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||
|
||||
@@ -385,10 +475,9 @@ class ACT(nn.Module):
|
||||
if self.use_images:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
|
||||
@@ -43,7 +43,13 @@ from lerobot.common.policies.utils import (
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
@@ -111,17 +117,18 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
||||
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not be the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -143,6 +150,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
|
||||
@@ -132,6 +132,7 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
@@ -197,6 +198,7 @@ class Unnormalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
|
||||
@@ -25,12 +25,16 @@ class TDMPCConfig:
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
@@ -100,6 +104,7 @@ class TDMPCConfig:
|
||||
# Input / output structure.
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -158,17 +163,18 @@ class TDMPCConfig:
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
@@ -179,3 +185,12 @@ class TDMPCConfig:
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
@@ -19,14 +19,10 @@
|
||||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
|
||||
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
|
||||
TODO(alexander-soare): Use batch-first throughout.
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
@@ -45,7 +41,13 @@ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -56,9 +58,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
|
||||
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
|
||||
`lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
name = "tdmpc"
|
||||
@@ -74,22 +78,6 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
logging.warning(
|
||||
"""
|
||||
Please note several warnings for this policy.
|
||||
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
@@ -114,8 +102,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -125,10 +119,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=1),
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=self.config.n_action_repeats),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
@@ -137,7 +134,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -151,49 +150,57 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
|
||||
if self.config.use_mpc:
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
# Batch processing is not handled in MPC mode, so process the batch in a loop.
|
||||
action = [] # will be a batch of actions for one step
|
||||
for i in range(batch_size):
|
||||
# Note: self.plan does not handle batches, hence the squeeze.
|
||||
action.append(self.plan(z[i]))
|
||||
action = torch.stack(action)
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
if self.config.use_mpc: # noqa: SIM108
|
||||
actions = self.plan(z) # (horizon, batch, action_dim)
|
||||
else:
|
||||
# Plan with the policy (π) alone.
|
||||
action = self.model.pi(z)
|
||||
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
|
||||
# sequence dimension like in the MPC branch.
|
||||
actions = self.model.pi(z).unsqueeze(0)
|
||||
|
||||
self.unnormalize_outputs({"action": action})["action"]
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(action)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.n_action_repeats > 1:
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(actions[0])
|
||||
else:
|
||||
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
|
||||
self._queues["action"].extend(actions[: self.config.n_action_steps])
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return torch.clamp(action, -1, 1)
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan next action using TD-MPC inference.
|
||||
"""Plan sequence of actions using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (latent_dim,) tensor for the initial state.
|
||||
z: (batch, latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(action_dim,) tensor for the next action.
|
||||
|
||||
TODO(alexander-soare) Extend this to be able to work with batches.
|
||||
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch_size = z.shape[0]
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
|
||||
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
@@ -202,12 +209,14 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
@@ -218,6 +227,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
@@ -226,21 +236,24 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
|
||||
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0)[0]
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum()
|
||||
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n -> n 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
@@ -255,11 +268,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score, 1).item()]
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
|
||||
# Select only the first action
|
||||
action = actions[0]
|
||||
return action
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
@@ -311,12 +322,17 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss."""
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
@@ -326,12 +342,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b)
|
||||
reward = batch["next.reward"] # (t,)
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
reward = batch["next.reward"] # (t, b)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self.config.max_random_shift_ratio > 0:
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
@@ -343,7 +359,9 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon = next_observations["observation.image"].shape[0]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
@@ -413,7 +431,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
(
|
||||
F.mse_loss(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
@@ -462,10 +481,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
|
||||
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
|
||||
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
|
||||
# parameter for it (see below where we compute the total loss).
|
||||
# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
|
||||
# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
|
||||
# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
|
||||
# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
|
||||
# loss).
|
||||
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
|
||||
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
|
||||
# other losses.
|
||||
@@ -726,6 +746,16 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
@@ -734,8 +764,11 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
@@ -38,7 +38,13 @@ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
@@ -98,6 +104,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
@@ -123,6 +130,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
@@ -287,7 +295,7 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
|
||||
423
lerobot/common/robot_devices/cameras/opencv.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import math
|
||||
import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
from lerobot.scripts.control_robot import busy_wait
|
||||
|
||||
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
# The maximum opencv device index depends on your operating system. For instance,
|
||||
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
|
||||
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
|
||||
# When you change the USB port or reboot the computer, the operating system might
|
||||
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
    """Probe the system for camera indices that OpenCV is able to open.

    On Linux the candidate indices are read from the `/dev/video*` device nodes. On any other
    platform every index in `range(max_index_search_range)` is tried.

    Args:
        raise_when_empty: If True, raise `OSError` when none of the candidates could be opened.
        max_index_search_range: Exclusive upper bound of the indices probed on non-Linux platforms.
            It is not used on Linux.

    Returns:
        The list of candidate indices for which `cv2.VideoCapture(index).isOpened()` was true.

    Raises:
        OSError: If `raise_when_empty` is True and no camera could be opened.
    """
    if platform.system() == "Linux":
        # Linux uses camera ports
        print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
        possible_camera_ids = []
        for port in Path("/dev").glob("video*"):
            camera_idx = int(str(port).replace("/dev/video", ""))
            possible_camera_ids.append(camera_idx)
    else:
        # Report the bound that is actually scanned (the argument), not the module-level default.
        print(
            "Mac or Windows detected. Finding available camera indices through "
            f"scanning all indices from 0 to {max_index_search_range}"
        )
        possible_camera_ids = range(max_index_search_range)

    camera_ids = []
    for camera_idx in possible_camera_ids:
        # Open and immediately release each candidate: we only want to know whether it can be opened.
        camera = cv2.VideoCapture(camera_idx)
        is_open = camera.isOpened()
        camera.release()

        if is_open:
            print(f"Camera found at index {camera_idx}")
            camera_ids.append(camera_idx)

    if raise_when_empty and len(camera_ids) == 0:
        raise OSError(
            "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
            "or your camera driver, or make sure your camera is compatible with opencv2."
        )

    return camera_ids
|
||||
|
||||
|
||||
def save_image(img_array, camera_index, frame_index, images_dir):
    """Write a single frame to disk as a PNG named after its camera and frame index.

    The file is written to `images_dir / "camera_XX_frame_XXXXXX.png"`; the parent directory is
    created if it does not exist yet.
    """
    pil_image = Image.fromarray(img_array)
    target = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png"
    target.parent.mkdir(parents=True, exist_ok=True)
    pil_image.save(str(target), quality=100)
|
||||
|
||||
|
||||
def save_images_from_cameras(
    images_dir: Path, camera_ids: list[int] | None = None, fps=None, width=None, height=None, record_time_s=2
):
    """Record frames from one or several cameras for a short time and save them as PNG files.

    Args:
        images_dir: Output directory. It is deleted first if it already exists, then recreated.
        camera_ids: Camera indices to use. When None, every index returned by
            `find_camera_indices()` is used.
        fps: Requested frames per second, forwarded to each `OpenCVCamera`. When None, frames
            are fetched with the blocking `read()`.
        width: Requested frame width, forwarded to each `OpenCVCamera`.
        height: Requested frame height, forwarded to each `OpenCVCamera`.
        record_time_s: Recording duration in seconds.
    """
    if camera_ids is None:
        camera_ids = find_camera_indices()

    print("Connecting cameras")
    cameras = []
    try:
        for cam_idx in camera_ids:
            camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
            camera.connect()
            print(
                f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
                f"height={camera.height}, color_mode={camera.color_mode})"
            )
            cameras.append(camera)

        images_dir = Path(images_dir)
        if images_dir.exists():
            shutil.rmtree(
                images_dir,
            )
        images_dir.mkdir(parents=True, exist_ok=True)

        print(f"Saving images to {images_dir}")
        frame_index = 0
        start_time = time.perf_counter()
        # Saving is delegated to a thread pool so that disk writes do not slow down the capture loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            while True:
                now = time.perf_counter()

                for camera in cameras:
                    # If we use async_read when fps is None, the loop will go full speed, and we will endup
                    # saving the same images from the cameras multiple times until the RAM/disk is full.
                    image = camera.read() if fps is None else camera.async_read()

                    executor.submit(
                        save_image,
                        image,
                        camera.camera_index,
                        frame_index,
                        images_dir,
                    )

                if fps is not None:
                    # Wait for the remainder of the frame period to hold the requested rate.
                    dt_s = time.perf_counter() - now
                    busy_wait(1 / fps - dt_s)

                if time.perf_counter() - start_time > record_time_s:
                    break

                print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")

                frame_index += 1
    finally:
        # Always release the devices. The background thread started by `async_read` keeps a
        # reference to its camera, so relying on `__del__` would leave thread and device open.
        for camera in cameras:
            if camera.is_connected:
                camera.disconnect()

    print(f"Images have been saved to {images_dir}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCVCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(30, 640, 480)
|
||||
OpenCVCameraConfig(60, 640, 480)
|
||||
OpenCVCameraConfig(90, 640, 480)
|
||||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color_mode values are 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
|
||||
class OpenCVCamera:
    """
    The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
    with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).

    An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
    like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
    might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.

    To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
    ```bash
    python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
    ```

    When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
    of the given camera will be used.

    Example of usage:
    ```python
    camera = OpenCVCamera(camera_index=0)
    camera.connect()
    color_image = camera.read()
    # when done using the camera, consider disconnecting
    camera.disconnect()
    ```

    Example of changing default fps, width, height and color_mode:
    ```python
    camera = OpenCVCamera(0, fps=30, width=1280, height=720)
    camera.connect()  # applies the settings, might error out if these settings are not compatible with the camera

    camera = OpenCVCamera(0, fps=90, width=640, height=480)
    camera.connect()

    camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
    camera.connect()
    ```
    """

    def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
        """Store the requested settings; the device itself is only opened by `connect()`.

        Args:
            camera_index: Index of the camera to open.
            config: Optional settings. A default `OpenCVCameraConfig` is used when None.
            **kwargs: Field overrides applied on top of `config`.
        """
        if config is None:
            config = OpenCVCameraConfig()
        # Overwrite config arguments using kwargs
        config = replace(config, **kwargs)

        self.camera_index = camera_index
        self.fps = config.fps
        self.width = config.width
        self.height = config.height
        self.color_mode = config.color_mode

        self.camera = None
        self.is_connected = False
        self.thread = None
        self.stop_event = None
        self.color_image = None
        self.logs = {}

    def connect(self):
        """Open the camera, apply the requested fps/width/height and verify they were accepted.

        Raises:
            RobotDeviceAlreadyConnectedError: If the camera is already connected.
            ValueError: If the camera cannot be opened and its index is not among the detected ones.
            OSError: If the camera cannot be opened, or if a requested setting was not applied.
        """
        if self.is_connected:
            raise RobotDeviceAlreadyConnectedError(f"Camera {self.camera_index} is already connected.")

        # First create a temporary camera trying to access `camera_index`,
        # and verify it is a valid camera by calling `isOpened`.

        if platform.system() == "Linux":
            # Linux uses ports for connecting to cameras
            tmp_camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
        else:
            tmp_camera = cv2.VideoCapture(self.camera_index)

        is_camera_open = tmp_camera.isOpened()
        # Release camera to make it accessible for `find_camera_indices`.
        # An explicit `release()` does not depend on when the object gets garbage collected.
        tmp_camera.release()
        del tmp_camera

        # If the camera doesn't work, display the camera indices corresponding to
        # valid cameras.
        if not is_camera_open:
            # Verify that the provided `camera_index` is valid before printing the traceback
            available_cam_ids = find_camera_indices()
            if self.camera_index not in available_cam_ids:
                raise ValueError(
                    f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
                    "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
                )

            raise OSError(f"Can't access camera {self.camera_index}.")

        # Secondly, create the camera that will be used downstream.
        # Note: For some unknown reason, calling `isOpened` blocks the camera which then
        # needs to be re-created.
        if platform.system() == "Linux":
            self.camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
        else:
            self.camera = cv2.VideoCapture(self.camera_index)

        if self.fps is not None:
            self.camera.set(cv2.CAP_PROP_FPS, self.fps)
        if self.width is not None:
            self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
        if self.height is not None:
            self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)

        # Read the settings back: the driver may silently ignore unsupported values.
        actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
        actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
        actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)

        if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
            raise OSError(
                f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
            )
        if self.width is not None and self.width != actual_width:
            raise OSError(
                f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
            )
        if self.height is not None and self.height != actual_height:
            raise OSError(
                f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
            )

        self.fps = actual_fps
        # `VideoCapture.get` returns floats; pixel dimensions are stored as integers.
        self.width = round(actual_width)
        self.height = round(actual_height)

        self.is_connected = True

    def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
        """Read a frame from the camera returned in the format (height, width, channels)
        (e.g. (480, 640, 3)), contrarily to the pytorch format which is channel first.

        Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
        If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.

        Args:
            temporary_color_mode: "rgb" or "bgr" to override `self.color_mode` for this call only.

        Raises:
            RobotDeviceNotConnectedError: If the camera is not connected.
            OSError: If no frame could be captured or if its size differs from the expected one.
            ValueError: If the requested color mode is neither "rgb" nor "bgr".
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        start_time = time.perf_counter()

        ret, color_image = self.camera.read()
        if not ret:
            raise OSError(f"Can't capture color image from camera {self.camera_index}.")

        requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode

        if requested_color_mode not in ["rgb", "bgr"]:
            raise ValueError(
                f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
            )

        # OpenCV uses BGR format as default (blue, green red) for all operations, including displaying images.
        # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
        # so we convert the image color from BGR to RGB.
        if requested_color_mode == "rgb":
            color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)

        h, w, _ = color_image.shape
        if h != self.height or w != self.width:
            raise OSError(
                f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
            )

        # log the number of seconds it took to read the image
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the utc time at which the image was received
        self.logs["timestamp_utc"] = capture_timestamp_utc()

        return color_image

    def read_loop(self):
        """Target of the background thread: keep refreshing `self.color_image` until asked to stop."""
        while self.stop_event is None or not self.stop_event.is_set():
            self.color_image = self.read()

    def async_read(self):
        """Return the most recent frame captured by the background thread, starting it on first call.

        Raises:
            RobotDeviceNotConnectedError: If the camera is not connected.
            Exception: If no frame arrived and the background thread is not running.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        if self.thread is None:
            self.stop_event = threading.Event()
            self.thread = Thread(target=self.read_loop, args=())
            self.thread.daemon = True
            self.thread.start()

        # Wait for the first frame, one frame period at a time.
        num_tries = 0
        while self.color_image is None:
            num_tries += 1
            time.sleep(1 / self.fps)
            if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
                raise Exception(
                    "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
                )

        return self.color_image

    def disconnect(self):
        """Stop the background thread if any, release the device and mark the camera as disconnected.

        Raises:
            RobotDeviceNotConnectedError: If the camera is not connected.
        """
        if not self.is_connected:
            raise RobotDeviceNotConnectedError(
                f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
            )

        if self.thread is not None:
            # Reset the thread state even when the thread already died (e.g. `read()` raised
            # inside it), otherwise `async_read()` would never start a new one after reconnecting.
            self.stop_event.set()
            if self.thread.is_alive():
                # wait for the thread to finish
                self.thread.join()
            self.thread = None
            self.stop_event = None

        # Drop the cached frame so that a later session cannot return an image from this one.
        self.color_image = None

        self.camera.release()
        self.camera = None

        self.is_connected = False

    def __del__(self):
        # `getattr` guards against `__init__` having failed before `is_connected` was set.
        if getattr(self, "is_connected", False):
            self.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line utility: connect to the selected (or all detected) cameras and save a few frames.
    parser = argparse.ArgumentParser(
        description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
    )
    parser.add_argument(
        "--camera-ids",
        type=int,
        nargs="*",
        default=None,
        help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=None,
        help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
    )
    # Width and height must be parsed as integers: `OpenCVCamera.connect` compares them with the
    # numeric values reported by the camera, so string values would always be rejected.
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Set the width for all cameras. If not provided, use the default width of each camera.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Set the height for all cameras. If not provided, use the default height of each camera.",
    )
    parser.add_argument(
        "--images-dir",
        type=Path,
        default="outputs/images_from_opencv_cameras",
        help="Set directory to save a few frames for each camera.",
    )
    parser.add_argument(
        "--record-time-s",
        type=float,
        default=2.0,
        help="Set the number of seconds used to record the frames. By default, 2 seconds.",
    )
    args = parser.parse_args()
    # The argument names match the keyword parameters of `save_images_from_cameras`.
    save_images_from_cameras(**vars(args))
|
||||
58
lerobot/common/robot_devices/cameras/utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
|
||||
def write_shape_on_image_inplace(image):
    """Draw a `Width: <w> Height: <h>` label near the bottom-left corner of `image`, in place."""
    img_height, img_width = image.shape[:2]
    cv2.putText(
        image,
        f"Width: {img_width} Height: {img_height}",
        (10, img_height - 10),  # 10 pixels from the bottom-left corner
        cv2.FONT_HERSHEY_SIMPLEX,
        1,  # font scale
        (255, 0, 0),  # Blue in BGR
        2,  # thickness
    )
|
||||
|
||||
|
||||
def save_color_image(image, path, write_shape=False):
    """Write `image` to `path` with OpenCV, creating the parent directory when missing.

    When `write_shape` is True the image dimensions are first drawn onto `image` itself.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    if write_shape:
        write_shape_on_image_inplace(image)
    cv2.imwrite(str(target), image)
|
||||
|
||||
|
||||
def save_depth_image(depth, path, write_shape=False):
    """Render `depth` with the JET colormap and write it to `path`, creating the parent directory.

    When `write_shape` is True the image dimensions are drawn onto the colorized image.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Apply colormap on depth image (image must be converted to 8-bit per pixel first)
    depth_8bit = cv2.convertScaleAbs(depth, alpha=0.03)
    colorized = cv2.applyColorMap(depth_8bit, cv2.COLORMAP_JET)

    if write_shape:
        write_shape_on_image_inplace(colorized)
    cv2.imwrite(str(target), colorized)
|
||||
|
||||
|
||||
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
    """Turn a channel-first (c, h, w) tensor into a channel-last numpy image.

    When `rgb_to_bgr` is True the channels are additionally converted from RGB to BGR.
    """
    assert tensor.ndim == 3
    channels, rows, cols = tensor.shape
    # Sanity check that the first axis really is the channel axis.
    assert channels < rows and channels < cols
    channel_last = einops.rearrange(tensor, "c h w -> h w c").numpy()
    if not rgb_to_bgr:
        return channel_last
    return cv2.cvtColor(channel_last, cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
|
||||
def async_read(self) -> np.ndarray: ...
|
||||
def disconnect(self): ...
|
||||
816
lerobot/common/robot_devices/motors/dynamixel.py
Normal file
@@ -0,0 +1,816 @@
|
||||
import enum
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from dynamixel_sdk import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
DXL_HIWORD,
|
||||
DXL_LOBYTE,
|
||||
DXL_LOWORD,
|
||||
GroupSyncRead,
|
||||
GroupSyncWrite,
|
||||
PacketHandler,
|
||||
PortHandler,
|
||||
)
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
# Protocol version handed to `PacketHandler` when a bus connects.
PROTOCOL_VERSION = 2.0
# Baudrate the bus is set to on connection, and that motors are configured to use.
BAUDRATE = 1_000_000
# Packet timeout handed to `PortHandler.setPacketTimeoutMillis`.
TIMEOUT_MS = 1000

# Exclusive upper bound of the motor ids scanned when searching for motors.
MAX_ID_RANGE = 252

# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270

# data_name: (address, size_byte)
X_SERIES_CONTROL_TABLE = {
    "Model_Number": (0, 2),
    "Model_Information": (2, 4),
    "Firmware_Version": (6, 1),
    "ID": (7, 1),
    "Baud_Rate": (8, 1),
    "Return_Delay_Time": (9, 1),
    "Drive_Mode": (10, 1),
    "Operating_Mode": (11, 1),
    "Secondary_ID": (12, 1),
    "Protocol_Type": (13, 1),
    "Homing_Offset": (20, 4),
    "Moving_Threshold": (24, 4),
    "Temperature_Limit": (31, 1),
    "Max_Voltage_Limit": (32, 2),
    "Min_Voltage_Limit": (34, 2),
    "PWM_Limit": (36, 2),
    "Current_Limit": (38, 2),
    "Acceleration_Limit": (40, 4),
    "Velocity_Limit": (44, 4),
    "Max_Position_Limit": (48, 4),
    "Min_Position_Limit": (52, 4),
    "Shutdown": (63, 1),
    "Torque_Enable": (64, 1),
    "LED": (65, 1),
    "Status_Return_Level": (68, 1),
    "Registered_Instruction": (69, 1),
    "Hardware_Error_Status": (70, 1),
    "Velocity_I_Gain": (76, 2),
    "Velocity_P_Gain": (78, 2),
    "Position_D_Gain": (80, 2),
    "Position_I_Gain": (82, 2),
    "Position_P_Gain": (84, 2),
    "Feedforward_2nd_Gain": (88, 2),
    "Feedforward_1st_Gain": (90, 2),
    "Bus_Watchdog": (98, 1),
    "Goal_PWM": (100, 2),
    "Goal_Current": (102, 2),
    "Goal_Velocity": (104, 4),
    "Profile_Acceleration": (108, 4),
    "Profile_Velocity": (112, 4),
    "Goal_Position": (116, 4),
    "Realtime_Tick": (120, 2),
    "Moving": (122, 1),
    "Moving_Status": (123, 1),
    "Present_PWM": (124, 2),
    "Present_Current": (126, 2),
    "Present_Velocity": (128, 4),
    "Present_Position": (132, 4),
    "Velocity_Trajectory": (136, 4),
    "Position_Trajectory": (140, 4),
    "Present_Input_Voltage": (144, 2),
    "Present_Temperature": (146, 1),
}

# "Baud_Rate" register value: baudrate in bits per second
X_SERIES_BAUDRATE_TABLE = {
    0: 9_600,
    1: 57_600,
    2: 115_200,
    3: 1_000_000,
    4: 2_000_000,
    5: 3_000_000,
    6: 4_000_000,
}

# NOTE(review): presumably the data names that go through calibration, respectively the
# unsigned-to-signed conversion, when read or written; the usage is outside this view.
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]

# motor model: control table
MODEL_CONTROL_TABLE = {
    "x_series": X_SERIES_CONTROL_TABLE,
    "xl330-m077": X_SERIES_CONTROL_TABLE,
    "xl330-m288": X_SERIES_CONTROL_TABLE,
    "xl430-w250": X_SERIES_CONTROL_TABLE,
    "xm430-w350": X_SERIES_CONTROL_TABLE,
    "xm540-w270": X_SERIES_CONTROL_TABLE,
}

# motor model: number of steps needed to achieve a full rotation
MODEL_RESOLUTION = {
    "x_series": 4096,
    "xl330-m077": 4096,
    "xl330-m288": 4096,
    "xl430-w250": 4096,
    "xm430-w350": 4096,
    "xm540-w270": 4096,
}

# motor model: baudrate table
MODEL_BAUDRATE_TABLE = {
    "x_series": X_SERIES_BAUDRATE_TABLE,
    "xl330-m077": X_SERIES_BAUDRATE_TABLE,
    "xl330-m288": X_SERIES_BAUDRATE_TABLE,
    "xl430-w250": X_SERIES_BAUDRATE_TABLE,
    "xm430-w350": X_SERIES_BAUDRATE_TABLE,
    "xm540-w270": X_SERIES_BAUDRATE_TABLE,
}

# Number of attempts allowed for a read, respectively a write, on the bus.
NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]):
    """This function converts the degree range to the step range for indicating motors rotation.
    It assumes a motor achieves a full rotation by going from -180 degree position to +180.
    The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.

    Args:
        degrees: Angle(s) in degrees, as a scalar or an array.
        models: One motor model name, or a list of model names (one per motor).

    Returns:
        A numpy integer array of steps with one entry per model, truncated toward zero.
    """
    # A bare model name must be wrapped in a list, otherwise the comprehension below
    # would iterate over its characters and fail the resolution lookup.
    if isinstance(models, str):
        models = [models]

    resolutions = [MODEL_RESOLUTION[model] for model in models]
    # `np.asarray` accepts every scalar type (not only `float`) and leaves arrays untouched.
    steps = np.asarray(degrees) / 180 * np.array(resolutions) / 2
    steps = steps.astype(int)
    return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
    """Split `value` into a list of `bytes` byte values, least significant byte first.

    Raises:
        NotImplementedError: If `bytes` is not 1, 2 or 4.
    """
    # Note: No need to convert back into unsigned int, since this byte preprocessing
    # already handles it for us.
    if bytes not in (1, 2, 4):
        raise NotImplementedError(
            f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
            f"{bytes} is provided instead."
        )

    low_word = DXL_LOWORD(value)
    data = [DXL_LOBYTE(low_word)]
    if bytes >= 2:
        data.append(DXL_HIBYTE(low_word))
    if bytes == 4:
        high_word = DXL_HIWORD(value)
        data.extend([DXL_LOBYTE(high_word), DXL_HIBYTE(high_word)])
    return data
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
    """Build the key `<data_name>_<motor1>_<motor2>_...` identifying a data name and a group of motors."""
    joined_names = "_".join(motor_names)
    return f"{data_name}_{joined_names}"
|
||||
|
||||
|
||||
def get_result_name(fn_name, data_name, motor_names):
    """Return `<fn_name>_<group key>`, the group key coming from `get_group_sync_key`."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
||||
|
||||
|
||||
def get_queue_name(fn_name, data_name, motor_names):
    """Return `<fn_name>_<group key>`, the group key coming from `get_group_sync_key`."""
    return f"{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
||||
|
||||
|
||||
def get_log_name(var_name, fn_name, data_name, motor_names):
    """Return `<var_name>_<fn_name>_<group key>`, the group key coming from `get_group_sync_key`."""
    return f"{var_name}_{fn_name}_{get_group_sync_key(data_name, motor_names)}"
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
|
||||
def find_available_ports():
    """List every `/dev/tty*` entry as a string path."""
    return [str(device) for device in Path("/dev").glob("tty*")]
|
||||
|
||||
|
||||
def find_port():
    """Interactively identify the bus port by listing the ports before and after the cable is unplugged.

    Raises:
        OSError: If no port, or more than one port, disappeared after unplugging.
    """
    print("Finding all available ports for the DynamixelMotorsBus.")
    before = find_available_ports()
    print(before)

    print("Remove the usb cable from your DynamixelMotorsBus and press Enter when done.")
    input()

    # Short pause before listing the ports again.
    time.sleep(0.5)
    after = find_available_ports()
    vanished = list(set(before) - set(after))

    if len(vanished) == 0:
        raise OSError(f"Could not detect the port. No difference was found ({vanished}).")
    if len(vanished) != 1:
        raise OSError(f"Could not detect the port. More than one port was found ({vanished}).")

    port = vanished[0]
    print(f"The port of this DynamixelMotorsBus is '{port}'")
    print("Reconnect the usb cable.")
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
    # NOTE(review): presumably the values written to the "Torque_Enable" control-table entry; confirm against callers.
    ENABLED = 1
    DISABLED = 0
|
||||
|
||||
|
||||
class OperatingMode(enum.Enum):
    # NOTE(review): presumably the values of the "Operating_Mode" control-table entry; confirm against callers.
    VELOCITY = 1
    POSITION = 3
    EXTENDED_POSITION = 4
    CURRENT_CONTROLLED_POSITION = 5
    PWM = 16
    # Sentinel that does not match any of the mode values above.
    UNKNOWN = -1
|
||||
|
||||
|
||||
class DriveMode(enum.Enum):
    # NOTE(review): presumably the values of the "Drive_Mode" control-table entry; confirm against callers.
    NON_INVERTED = 0
    INVERTED = 1
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
|
||||
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/motors/dynamixel.py
|
||||
>>> Finding all available ports for the DynamixelMotorsBus.
|
||||
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example of usage for 1 motor connected to the bus:
|
||||
```python
|
||||
motor_name = "gripper"
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# when done, consider disconnecting
|
||||
motors_bus.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
self.is_connected = False
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
"\nTry running `python lerobot/common/robot_devices/motors/dynamixel.py` to make sure you are using the correct port.\n"
|
||||
)
|
||||
raise
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Set expected baudrate for the bus
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
|
||||
if not self.are_motors_configured():
|
||||
input(
|
||||
"\n/!\\ A configuration issue has been detected with your motors: \n"
|
||||
"If it's the first time that you use these motors, press enter to configure your motors... but before "
|
||||
"verify that all the cables are connected the proper way. If you find an issue, before making a modification, "
|
||||
"kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power "
|
||||
"again and relaunch the script.\n"
|
||||
)
|
||||
print()
|
||||
self.configure_motors()
|
||||
|
||||
def reconnect(self):
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
||||
# a ConnectionError will be raised anyway.
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def configure_motors(self):
|
||||
# TODO(rcadene): This script assumes motors follow the X_SERIES baudrates
|
||||
# TODO(rcadene): Refactor this function with intermediate high-level functions
|
||||
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(X_SERIES_BAUDRATE_TABLE.values())
|
||||
ids_per_baudrate = {}
|
||||
for baudrate in all_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices()
|
||||
if len(present_ids) > 0:
|
||||
ids_per_baudrate[baudrate] = present_ids
|
||||
print(f"Motor indices detected: {ids_per_baudrate}")
|
||||
print()
|
||||
|
||||
possible_baudrates = list(ids_per_baudrate.keys())
|
||||
possible_ids = list({idx for sublist in ids_per_baudrate.values() for idx in sublist})
|
||||
untaken_ids = list(set(range(MAX_ID_RANGE)) - set(possible_ids) - set(self.motor_indices))
|
||||
|
||||
# Connect successively one motor to the chain and write a unique random index for each
|
||||
for i in range(len(self.motors)):
|
||||
self.disconnect()
|
||||
input(
|
||||
"1. Unplug the power cord\n"
|
||||
"2. Plug/unplug minimal number of cables to only have the first "
|
||||
f"{i+1} motor(s) ({self.motor_names[:i+1]}) connected.\n"
|
||||
"3. Re-plug the power cord\n"
|
||||
"Press Enter to continue..."
|
||||
)
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
if i > 0:
|
||||
try:
|
||||
self._read_with_motor_ids(self.motor_models, untaken_ids[:i], "ID")
|
||||
except ConnectionError:
|
||||
print(f"Failed to read from {untaken_ids[:i+1]}. Make sure the power cord is plugged in.")
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
print("Scanning possible baudrates and motor indices")
|
||||
motor_found = False
|
||||
for baudrate in possible_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices(possible_ids)
|
||||
if len(present_ids) == 1:
|
||||
present_idx = present_ids[0]
|
||||
print(f"Detected motor with index {present_idx}")
|
||||
|
||||
if baudrate != BAUDRATE:
|
||||
print(f"Setting its baudrate to {BAUDRATE}")
|
||||
baudrate_idx = list(X_SERIES_BAUDRATE_TABLE.values()).index(BAUDRATE)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
for _ in range(NUM_WRITE_RETRY):
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate", baudrate_idx
|
||||
)
|
||||
time.sleep(0.5)
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
try:
|
||||
present_baudrate_idx = self._read_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate"
|
||||
)
|
||||
except ConnectionError:
|
||||
print("Failed to write baudrate. Retrying.")
|
||||
self.set_bus_baudrate(baudrate)
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
if present_baudrate_idx != baudrate_idx:
|
||||
raise OSError("Failed to write baudrate.")
|
||||
|
||||
print(f"Setting its index to a temporary untaken index ({untaken_ids[i]})")
|
||||
self._write_with_motor_ids(self.motor_models, present_idx, "ID", untaken_ids[i])
|
||||
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, untaken_ids[i], "ID")
|
||||
if present_idx != untaken_ids[i]:
|
||||
raise OSError("Failed to write index.")
|
||||
|
||||
motor_found = True
|
||||
break
|
||||
elif len(present_ids) > 1:
|
||||
raise OSError(f"More than one motor detected ({present_ids}), but only one was expected.")
|
||||
|
||||
if not motor_found:
|
||||
raise OSError(
|
||||
"No motor found, but one new motor expected. Verify power cord is plugged in and retry."
|
||||
)
|
||||
print()
|
||||
|
||||
print(f"Setting expected motor indices: {self.motor_indices}")
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, untaken_ids[: len(self.motors)], "ID", self.motor_indices
|
||||
)
|
||||
print()
|
||||
|
||||
if (self.read("ID") != self.motor_indices).any():
|
||||
raise OSError("Failed to write motors indices.")
|
||||
|
||||
print("Configuration is done!")
|
||||
|
||||
def find_motor_indices(self, possible_ids=None):
    """Probe the bus and return the ids that answered an "ID" read.

    Args:
        possible_ids: iterable of ids to probe. When None, every id in
            `range(MAX_ID_RANGE)` is probed.

    Returns:
        The list of probed ids that answered, in probing order.

    Raises:
        OSError: if a motor answers with an id different from the one used to address it.
    """
    candidates = range(MAX_ID_RANGE) if possible_ids is None else possible_ids

    found = []
    for candidate in tqdm.tqdm(candidates):
        try:
            reported = self._read_with_motor_ids(self.motor_models, [candidate], "ID")[0]
        except ConnectionError:
            # Nothing answered on this id: move on to the next candidate.
            continue

        if candidate == reported:
            found.append(candidate)
            continue

        # sanity check
        raise OSError(
            "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
        )

    return found
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
    """Switch the port handler to `baudrate` when it is not already using it.

    Raises:
        OSError: if the port handler still reports another baud rate after the change.
    """
    previous = self.port_handler.getBaudRate()
    if previous == baudrate:
        # Already at the requested rate: nothing to do.
        return

    print(f"Setting bus baud rate to {baudrate}. Previously {previous}.")
    self.port_handler.setBaudRate(baudrate)

    # Read the rate back to make sure the change was taken into account.
    if self.port_handler.getBaudRate() != baudrate:
        raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
def motor_names(self) -> list[str]:
    """Names of the motors, in the insertion order of `self.motors`."""
    return [name for name in self.motors]
|
||||
|
||||
@property
def motor_models(self) -> list[str]:
    """Model of each motor, in the insertion order of `self.motors`."""
    models = []
    # Each entry of `self.motors` is an (index, model) pair.
    for _index, model_name in self.motors.values():
        models.append(model_name)
    return models
|
||||
|
||||
@property
def motor_indices(self) -> list[int]:
    """Bus index of each motor, in the insertion order of `self.motors`."""
    indices = []
    # Each entry of `self.motors` is an (index, model) pair.
    for motor_index, _model in self.motors.values():
        indices.append(motor_index)
    return indices
|
||||
|
||||
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
    """Store the calibration, a mapping from motor name to (homing_offset, drive_mode).

    `read` and `write` use it (when not None) to convert the values of calibrated registers.
    """
    self.calibration = calibration
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
|
||||
values = values.astype(np.int32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
|
||||
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range ]-resolution, resolution[ to the universal float32 centered degree range ]-180, 180[
|
||||
values = values.astype(np.float32)
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / (resolution // 2) * 180
|
||||
|
||||
return values
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from the universal float32 centered degree range ]-180, 180[ to resolution range ]-resolution, resolution[
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / 180 * (resolution // 2)
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
|
||||
# Convert from nominal range ]-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
values[i] -= homing_offset
|
||||
|
||||
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
|
||||
# than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
return values
|
||||
|
||||
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
    """Read `data_name` from the motors addressed by `motor_ids` with a one-shot sync read.

    Unlike `read`, the motors are addressed by raw bus id, no group reader is cached,
    no retry is attempted and no calibration is applied.

    Args:
        motor_models: models used to look up the address and size of `data_name`.
            They are checked to share the same address for `data_name`.
        motor_ids: a single id, or a list of ids.
        data_name: name of the control table entry to read.

    Returns:
        A list of values when `motor_ids` is a list, otherwise a single value.

    Raises:
        ConnectionError: if the sync read transaction does not succeed.
    """
    return_list = isinstance(motor_ids, list)
    if not return_list:
        motor_ids = [motor_ids]

    # Validate the models that were actually passed in (and that are used for the address lookup
    # right below), not the models configured on the bus. This mirrors `_write_with_motor_ids`.
    assert_same_address(self.model_ctrl_table, motor_models, data_name)
    addr, num_bytes = self.model_ctrl_table[motor_models[0]][data_name]

    group = GroupSyncRead(self.port_handler, self.packet_handler, addr, num_bytes)
    for idx in motor_ids:
        group.addParam(idx)

    comm = group.txRxPacket()
    if comm != COMM_SUCCESS:
        raise ConnectionError(
            f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    values = [group.getData(idx, addr, num_bytes) for idx in motor_ids]

    if return_list:
        return values
    return values[0]
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
    """Read `data_name` from the given motors and return the values as a numpy array.

    Args:
        data_name: name of the control table entry to read (e.g. "Present_Position").
        motor_names: a motor name, a list of motor names, or None for all motors of the bus.

    Returns:
        A numpy array with one value per motor. Entries listed in
        `CONVERT_UINT32_TO_INT32_REQUIRED` are cast to int32, and entries listed in
        `CALIBRATION_REQUIRED` are converted with `apply_calibration` when a calibration is set.

    Raises:
        RobotDeviceNotConnectedError: if the bus is not connected.
        ConnectionError: if the sync read still fails after `NUM_READ_RETRY` attempts.
        ValueError: if a calibrated value falls outside ]-270, 270[ degrees.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    # All models share the same address for `data_name` (checked here), so the first one is enough
    # to look it up.
    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, num_bytes = self.model_ctrl_table[models[0]][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # Group readers are cached under `group_key`, so the lookup must use that same key.
    # Testing `data_name` never matched and rebuilt the reader on every call.
    if group_key not in self.group_readers:
        # create new group reader
        self.group_readers[group_key] = GroupSyncRead(
            self.port_handler, self.packet_handler, addr, num_bytes
        )
        for idx in motor_ids:
            self.group_readers[group_key].addParam(idx)

    # The read can fail, so we allow retries
    for _ in range(NUM_READ_RETRY):
        comm = self.group_readers[group_key].txRxPacket()
        if comm == COMM_SUCCESS:
            break

    if comm != COMM_SUCCESS:
        raise ConnectionError(
            f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    values = []
    for idx in motor_ids:
        value = self.group_readers[group_key].getData(idx, addr, num_bytes)
        values.append(value)

    values = np.array(values)

    # Convert to signed int to use range [-2048, 2048] for our motor positions.
    if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
        values = values.astype(np.int32)

    if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
        values = self.apply_calibration(values, motor_names)

        # We expect our motors to stay in a nominal range of [-180, 180] degrees
        # which corresponds to a half turn rotation.
        # However, some motors can turn a bit more, hence we extend the nominal range to [-270, 270]
        # which is less than a full 360 degree rotation.
        # This check only makes sense on calibrated values (degrees), hence its place in this branch.
        if not np.all((values > -270) & (values < 270)):
            raise ValueError(
                f"Wrong motor position range detected. "
                f"Expected to be in [-270, +270] but in [{values.min()}, {values.max()}]. "
                "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
                "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
            )

    # log the number of seconds it took to read the data from the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # log the utc time at which the data was received
    ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()

    return values
|
||||
|
||||
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
    """Write `values` to `data_name` on the motors addressed by `motor_ids` with a one-shot sync write.

    Args:
        motor_models: models used to look up the address and size of `data_name`.
        motor_ids: a single id, or a list of ids.
        data_name: name of the control table entry to write.
        values: a single value, or a list with one value per id.

    Raises:
        ConnectionError: if sending the packet does not succeed.
    """
    # Accept scalars as well as lists for both ids and values.
    motor_ids = motor_ids if isinstance(motor_ids, list) else [motor_ids]
    values = values if isinstance(values, list) else [values]

    assert_same_address(self.model_ctrl_table, motor_models, data_name)
    address, size = self.model_ctrl_table[motor_models[0]][data_name]

    sync_writer = GroupSyncWrite(self.port_handler, self.packet_handler, address, size)
    for motor_id, value in zip(motor_ids, values, strict=True):
        sync_writer.addParam(motor_id, convert_to_bytes(value, size))

    status = sync_writer.txPacket()
    if status == COMM_SUCCESS:
        return

    raise ConnectionError(
        f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
        f"{self.packet_handler.getTxRxResult(status)}"
    )
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
    """Write `values` to `data_name` on the given motors.

    Args:
        data_name: name of the control table entry to write (e.g. "Goal_Position").
        values: a scalar applied to every motor, or one value per motor.
            For entries listed in `CALIBRATION_REQUIRED` (when a calibration is set) the values
            are converted back to raw steps with `revert_calibration`.
        motor_names: a motor name, a list of motor names, or None for all motors of the bus.

    Raises:
        RobotDeviceNotConnectedError: if the bus is not connected.
        ConnectionError: if sending the packet does not succeed.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
        )

    start_time = time.perf_counter()

    if motor_names is None:
        motor_names = self.motor_names

    if isinstance(motor_names, str):
        motor_names = [motor_names]

    needs_calibration = data_name in CALIBRATION_REQUIRED and self.calibration is not None

    if isinstance(values, (int, float, np.integer, np.floating)):
        # A calibrated scalar must keep its fractional part until `revert_calibration` has
        # converted it to steps: truncating it first loses precision (35.156 degrees would become
        # 35 degrees, i.e. 398 steps instead of 400 on a 4096-step motor).
        # Raw register values, on the other hand, must be integers.
        scalar = float(values) if needs_calibration else int(values)
        values = [scalar] * len(motor_names)

    values = np.array(values)

    motor_ids = []
    models = []
    for name in motor_names:
        motor_idx, model = self.motors[name]
        motor_ids.append(motor_idx)
        models.append(model)

    if needs_calibration:
        values = self.revert_calibration(values, motor_names)

    values = values.tolist()

    # All models share the same address for `data_name` (checked here), so the first one is enough
    # to look it up.
    assert_same_address(self.model_ctrl_table, models, data_name)
    addr, num_bytes = self.model_ctrl_table[models[0]][data_name]
    group_key = get_group_sync_key(data_name, motor_names)

    # Group writers are cached under `group_key` in `self.group_writers`, so that is what must be
    # looked up. Testing `data_name` against the *readers* never matched and rebuilt the writer on
    # every call.
    init_group = group_key not in self.group_writers
    if init_group:
        self.group_writers[group_key] = GroupSyncWrite(
            self.port_handler, self.packet_handler, addr, num_bytes
        )

    for idx, value in zip(motor_ids, values, strict=True):
        data = convert_to_bytes(value, num_bytes)
        if init_group:
            self.group_writers[group_key].addParam(idx, data)
        else:
            # The motor is already registered in the cached writer: only update its payload.
            self.group_writers[group_key].changeParam(idx, data)

    comm = self.group_writers[group_key].txPacket()
    if comm != COMM_SUCCESS:
        raise ConnectionError(
            f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
            f"{self.packet_handler.getTxRxResult(comm)}"
        )

    # log the number of seconds it took to write the data to the motors
    delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
    self.logs[delta_ts_name] = time.perf_counter() - start_time

    # TODO(rcadene): should we log the time before sending the write command?
    # log the utc time when the write has been completed
    ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
    self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
def disconnect(self):
    """Close the port and reset the connection state of the bus.

    Raises:
        RobotDeviceNotConnectedError: if the bus is not connected.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
        )

    handler = self.port_handler
    if handler is not None:
        handler.closePort()
        self.port_handler = None

    # Drop everything that was bound to the closed port.
    self.packet_handler = None
    self.group_readers = {}
    self.group_writers = {}
    self.is_connected = False
|
||||
|
||||
def __del__(self):
    """Disconnect the bus when the object is deleted while still connected."""
    # `getattr` with a default: the attribute may be missing on a partially built object.
    still_connected = getattr(self, "is_connected", False)
    if not still_connected:
        return
    self.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # When run as a script: helper to find the usb port associated with each DynamixelMotorsBus.
    find_port()
|
||||
10
lerobot/common/robot_devices/motors/utils.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
def set_calibration(self): ...
|
||||
def apply_calibration(self): ...
|
||||
def revert_calibration(self): ...
|
||||
def read(self): ...
|
||||
def write(self): ...
|
||||
7
lerobot/common/robot_devices/robots/factory.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_robot(cfg: DictConfig):
    """Instantiate and return the robot described by the hydra config `cfg`."""
    return hydra.utils.instantiate(cfg)
|
||||
515
lerobot/common/robot_devices/robots/koch.py
Normal file
@@ -0,0 +1,515 @@
|
||||
import pickle
|
||||
import time
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
|
||||
########################################################################
|
||||
# Calibration logic
|
||||
########################################################################
|
||||
|
||||
# Pattern of the URL of the picture shown to the user for each calibration position.
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"

# Target angles, expressed in the nominal degree range ]-180, +180[
ZERO_POSITION_DEGREE = 0
ROTATED_POSITION_DEGREE = 90
GRIPPER_OPEN_DEGREE = 35.156
|
||||
|
||||
|
||||
def assert_drive_mode(drive_mode):
    """Raise a ValueError unless every entry of `drive_mode` is 0 or 1.

    0 means the original rotation direction of the motor, 1 means inverted.
    """
    is_binary = np.isin(drive_mode, [0, 1])
    if np.all(is_binary):
        return
    raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
    """Flip, in place, the sign of the positions whose drive mode is 1 and return `position`."""
    assert_drive_mode(drive_mode)
    # Map `drive_mode` from {0, 1} (0: original rotation direction, 1: inverted)
    # to a sign in {+1, -1} (+1: original rotation direction, -1: inverted).
    direction_sign = -(drive_mode * 2 - 1)
    position *= direction_sign
    return position
|
||||
|
||||
|
||||
def reset_torque_mode(arm: MotorsBus):
    """Disable torque on every motor of `arm` and set the operating mode of each motor.

    Every motor except the one named "gripper" is set to extended position mode; the gripper,
    when the arm has one, is set to current-controlled position mode.
    """
    # To be configured, all servos must be in "torque disable" mode
    arm.write("Torque_Enable", TorqueMode.DISABLED.value)

    # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
    # rotate more than 360 degrees (from 0 to 4095). Since mistakes can happen while assembling the arm,
    # you could end up with a servo with a position 0 or 4095 at a crucial point.
    # See https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
    all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
    if len(all_motors_except_gripper) > 0:
        arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)

    # Use 'position control current based' for gripper to be limited by the limit of the current.
    # For the follower gripper, it means it can grasp an object without forcing too much even though
    # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
    # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
    # to make it move, and it will move back to its original target position when we release the force.
    # Guarded like the write above, so that an arm without a "gripper" motor is configured
    # instead of failing on the missing motor name.
    if "gripper" in arm.motor_names:
        arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
|
||||
|
||||
|
||||
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
    """Interactively compute the homing offset and the drive mode of each motor of `arm`.

    The goal of calibration is that a neural network trained on data collected on one robot can
    work on another robot: before calibration, the same goal position sent to two robots gives
    two very different postures; after calibration both robots move to the same posture.

    Homing offset shifts the motor position to a ]-2048, +2048[ nominal range (when the motor uses
    2048 steps to complete half a turn). This range is set around an arbitrary "zero position"
    corresponding to all motor positions being 0. During the procedure, the user is asked to
    manually move the robot to this "zero position".

    Drive mode inverts the rotation direction of a motor, which is useful when some motors have
    been assembled in the opposite orientation on some robots. During the procedure, the user is
    asked to manually move the robot to the "rotated position".

    Args:
        arm: motors bus of the arm to calibrate.
        name: name of the arm, only used in the printed messages.
        arm_type: "follower" or "leader" in the visible callers; used in the printed messages
            and to build the URL of the reference pictures.

    Returns:
        A tuple `(homing_offset, drive_mode)`, with one entry per motor of the arm.

    Example of usage:
    ```python
    run_arm_calibration(arm, "left", "follower")
    ```
    """

    def _snap_to_quarter_turn(raw_position, models):
        # TODO(rcadene): Rework this function since some motors cant physically rotate a quarter turn
        # (e.g. the gripper of Aloha arms can only rotate ~50 degree)
        quarter_turn_degree = 90
        quarter_turn = convert_degrees_to_steps(quarter_turn_degree, models)
        snapped = np.round(raw_position.astype(float) / quarter_turn) * quarter_turn
        return snapped.astype(raw_position.dtype)

    reset_torque_mode(arm)

    print(f"\nRunning calibration of {name} {arm_type}...")

    print("\nMove arm to zero position")
    print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="zero"))
    input("Press Enter to continue...")

    # The zero target position was arbitrarily chosen to be a straight horizontal position with the
    # gripper upwards and closed. It is easy to identify and all motors are in a "quarter turn"
    # position. Once calibration is done, this position corresponds to every motor angle being 0:
    # setting all 0 as Goal Position moves the arm to this position.
    zero_target = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)

    # Compute homing offset so that `present_position + homing_offset ~= target_position`.
    present = arm.read("Present_Position")
    present = _snap_to_quarter_turn(present, arm.motor_models)
    homing_offset = zero_target - present

    print("\nMove arm to rotated target position")
    print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rotated"))
    input("Press Enter to continue...")

    # The rotated target position corresponds to a rotation of a quarter turn from the zero position.
    # This allows to identify the rotation direction of each motor.
    # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing
    # offset, then we know its rotation direction is inverted. However, for the calibration to be
    # successful, everyone needs to follow the same target position.
    # Sometimes, there is only one possible rotation direction. For instance, if the gripper is
    # closed, there is only one direction which corresponds to opening the gripper. When the
    # rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view of
    # the previous motor in the kinetic chain.
    rotated_target = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)

    # Find drive mode by rotating each motor by a quarter of a turn.
    # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
    present = arm.read("Present_Position")
    present += homing_offset
    present = _snap_to_quarter_turn(present, arm.motor_models)
    drive_mode = (present != rotated_target).astype(np.int32)

    # Re-compute homing offset to take into account drive mode
    present = arm.read("Present_Position")
    present = apply_drive_mode(present, drive_mode)
    present = _snap_to_quarter_turn(present, arm.motor_models)
    homing_offset = rotated_target - present

    print("\nMove arm to rest position")
    print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rest"))
    input("Press Enter to continue...")
    print()

    return homing_offset, drive_mode
|
||||
|
||||
|
||||
########################################################################
|
||||
# Alexander Koch robot arm
|
||||
########################################################################
|
||||
|
||||
|
||||
@dataclass
class KochRobotConfig:
    """Components of a Koch robot: leader arms, follower arms and cameras, each keyed by name.

    Every field defaults to a new empty dict.

    Example of usage:
    ```python
    KochRobotConfig()
    ```
    """

    # Define all components of the robot
    leader_arms: dict[str, MotorsBus] = field(default_factory=dict)
    follower_arms: dict[str, MotorsBus] = field(default_factory=dict)
    cameras: dict[str, Camera] = field(default_factory=dict)
|
||||
|
||||
|
||||
class KochRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
"""This class allows to control any Koch robot of various number of motors.
|
||||
|
||||
A few versions are available:
|
||||
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, which was developed
|
||||
by Alexander Koch from [Tau Robotics](https://tau-robotics.com): [Github for sourcing and assembly](
|
||||
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1), which was developed by Jess Moss.
|
||||
|
||||
Example of highest frequency teleoperation without camera:
|
||||
```python
|
||||
# Defines how to communicate with the motors of the leader and follower arms
|
||||
leader_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
}
|
||||
follower_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
}
|
||||
robot = KochRobot(leader_arms=leader_arms, follower_arms=follower_arms)
|
||||
|
||||
# Connect motors buses and cameras if any (Required)
|
||||
robot.connect()
|
||||
|
||||
while True:
|
||||
robot.teleop_step()
|
||||
```
|
||||
|
||||
Example of highest frequency data collection without camera:
|
||||
```python
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms=leader_arms, follower_arms=follower_arms)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of highest frequency data collection with cameras:
|
||||
```python
|
||||
# Defines how to communicate with 2 cameras connected to the computer.
|
||||
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
||||
# can be reached respectively using the camera indices 0 and 1. These indices can be
|
||||
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
|
||||
cameras = {
|
||||
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
||||
}
|
||||
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms=leader_arms, follower_arms=follower_arms, cameras=cameras)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
|
||||
```python
|
||||
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
|
||||
robot = KochRobot(leader_arms=leader_arms, follower_arms=follower_arms, cameras=cameras)
|
||||
robot.connect()
|
||||
while True:
|
||||
# Uses the follower arms and cameras to capture an observation
|
||||
observation = robot.capture_observation()
|
||||
|
||||
# Assumes a policy has been instantiated
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Orders the robot to move
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
|
||||
```python
|
||||
robot.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: KochRobotConfig | None = None,
    calibration_path: Path = ".cache/calibration/koch.pkl",
    **kwargs,
):
    """Build the robot from a config, optionally overridden field by field through `kwargs`.

    Args:
        config: components of the robot. When None, a default `KochRobotConfig()` is used.
        calibration_path: location of the calibration file; converted to a `Path`.
        **kwargs: config fields to overwrite (applied with `dataclasses.replace`).
    """
    base_config = KochRobotConfig() if config is None else config
    # Overwrite config arguments using kwargs
    self.config = replace(base_config, **kwargs)
    self.calibration_path = Path(calibration_path)

    # Shortcuts to the components declared in the config.
    self.leader_arms = self.config.leader_arms
    self.follower_arms = self.config.follower_arms
    self.cameras = self.config.cameras

    self.is_connected = False
    self.logs = {}
|
||||
|
||||
def connect(self):
    """Connect every arm and camera, load (or run and save) the calibration, and enable torque.

    Steps, in order: connect follower then leader arms; reset their torque/operating modes and
    load the calibration file, or run the interactive calibration and save it when the file is
    missing; apply the calibration; set PID gains on the follower "elbow_flex"; enable torque on
    the followers and on the leader grippers; connect the cameras.

    Raises:
        RobotDeviceAlreadyConnectedError: if the robot is already connected.
        ValueError: if the robot has no arm and no camera at all.
    """
    if self.is_connected:
        raise RobotDeviceAlreadyConnectedError(
            "KochRobot is already connected. Do not run `robot.connect()` twice."
        )

    if not self.leader_arms and not self.follower_arms and not self.cameras:
        raise ValueError(
            "KochRobot doesn't have any device to connect. See example of usage in docstring of the class."
        )

    # Connect the arms.
    # Followers and leaders are iterated separately: indexing `self.leader_arms` with follower
    # names fails when both dicts do not share the exact same keys, and leaves leader arms
    # without a matching follower unconnected.
    for name in self.follower_arms:
        print(f"Connecting {name} follower arm.")
        self.follower_arms[name].connect()
    for name in self.leader_arms:
        print(f"Connecting {name} leader arm.")
        self.leader_arms[name].connect()

    # Reset the arms and load or run calibration
    if self.calibration_path.exists():
        # Reset all arms before setting calibration
        for name in self.follower_arms:
            reset_torque_mode(self.follower_arms[name])
        for name in self.leader_arms:
            reset_torque_mode(self.leader_arms[name])

        # NOTE(security): unpickling executes arbitrary code, so the calibration file must only
        # come from a trusted location (it is normally written by this same method).
        with open(self.calibration_path, "rb") as f:
            calibration = pickle.load(f)
    else:
        print(f"Missing calibration file '{self.calibration_path}'. Starting calibration precedure.")
        # Run calibration process which begins by resetting all arms
        calibration = self.run_calibration()

        print(f"Calibration is done! Saving calibration file '{self.calibration_path}'")
        self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.calibration_path, "wb") as f:
            pickle.dump(calibration, f)

    # Set calibration
    for name in self.follower_arms:
        self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
    for name in self.leader_arms:
        self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])

    # Set better PID values to close the gap between recorded states and actions
    # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
    for name in self.follower_arms:
        self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
        self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
        self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")

    # Enable torque on all motors of the follower arms
    for name in self.follower_arms:
        print(f"Activating torque on {name} follower arm.")
        self.follower_arms[name].write("Torque_Enable", 1)

    # Enable torque on the gripper of the leader arms, and move it to `GRIPPER_OPEN_DEGREE`,
    # so that we can use it as a trigger to close the gripper of the follower arms.
    for name in self.leader_arms:
        self.leader_arms[name].write("Torque_Enable", 1, "gripper")
        self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN_DEGREE, "gripper")

    # Connect the cameras
    for name in self.cameras:
        self.cameras[name].connect()

    self.is_connected = True
|
||||
|
||||
def run_calibration(self):
    """Run the interactive calibration of every follower arm, then of every leader arm.

    Returns:
        A dict keyed by "follower_{name}" / "leader_{name}", each value mapping a motor name to
        its `(homing_offset, drive_mode)` pair.
    """
    calibration = {}

    # Followers are calibrated first, then leaders.
    for arm_type, arms in (("follower", self.follower_arms), ("leader", self.leader_arms)):
        for name in arms:
            homing_offset, drive_mode = run_arm_calibration(arms[name], name, arm_type)

            calibration[f"{arm_type}_{name}"] = {
                motor_name: (homing_offset[idx], drive_mode[idx])
                for idx, motor_name in enumerate(arms[name].motor_names)
            }

    return calibration
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# Prepare to assign the position of the leader to the follower
|
||||
leader_pos = {}
|
||||
for name in self.leader_arms:
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
|
||||
follower_goal_pos = {}
|
||||
for name in self.leader_arms:
|
||||
follower_goal_pos[name] = leader_pos[name]
|
||||
|
||||
# Send action
|
||||
for name in self.follower_arms:
|
||||
before_fwrite_t = time.perf_counter()
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
# TODO(rcadene): Add velocity and other info
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
for name in self.follower_arms:
|
||||
if name in follower_pos:
|
||||
state.append(follower_pos[name])
|
||||
state = np.concatenate(state)
|
||||
|
||||
# Create action by concatenating follower goal position
|
||||
action = []
|
||||
for name in self.follower_arms:
|
||||
if name in follower_goal_pos:
|
||||
action.append(follower_goal_pos[name])
|
||||
action = np.concatenate(action)
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
action_dict["action"] = torch.from_numpy(action)
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self):
    """Read the follower positions and camera frames without sending any action.

    The returned observations do not have a batch dimension.

    Returns:
        A dict with "observation.state" (concatenated follower positions) and one
        "observation.images.<name>" entry per camera, all built with `torch.from_numpy`.

    Raises:
        RobotDeviceNotConnectedError: If `self.is_connected` is false.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            "KochRobot is not connected. You need to run `robot.connect()`."
        )

    # Read follower position
    follower_pos = {}
    for name, follower in self.follower_arms.items():
        t_start = time.perf_counter()
        follower_pos[name] = follower.read("Present_Position")
        self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - t_start

    # Create state by concatenating follower current position
    state = np.concatenate([follower_pos[name] for name in self.follower_arms if name in follower_pos])

    # Capture images from cameras
    images = {}
    for name, camera in self.cameras.items():
        t_start = time.perf_counter()
        images[name] = camera.async_read()
        self.logs[f"read_camera_{name}_dt_s"] = camera.logs["delta_timestamp_s"]
        self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - t_start

    # Populate output dictionary and format to pytorch
    obs_dict = {"observation.state": torch.from_numpy(state)}
    for name in self.cameras:
        obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
    return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
    """Split a flat action vector across the follower arms and write it as goal positions.

    The provided action is expected to be a vector. It is consumed in the iteration
    order of `self.follower_arms`: each arm takes as many consecutive values as it has
    entries in `motor_names`.

    Args:
        action: 1-D tensor of goal positions. It may require grad or live on a non-CPU
            device; it is detached and moved to CPU before the numpy conversion.

    Raises:
        RobotDeviceNotConnectedError: If `self.is_connected` is false.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            "KochRobot is not connected. You need to run `robot.connect()`."
        )

    # Slice the whole vector first, so nothing is written if a slice/conversion fails.
    from_idx = 0
    follower_goal_pos = {}
    for name in self.follower_arms:
        to_idx = from_idx + len(self.follower_arms[name].motor_names)
        # `.numpy()` alone raises for tensors that require grad or are not on the CPU.
        follower_goal_pos[name] = action[from_idx:to_idx].detach().cpu().numpy()
        from_idx = to_idx

    for name in self.follower_arms:
        # The cast to int32 truncates any fractional part of the goal position.
        self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
|
||||
|
||||
def disconnect(self):
    """Disconnect every follower arm, then every leader arm, then every camera.

    Sets `self.is_connected` to False once all devices have been disconnected.

    Raises:
        RobotDeviceNotConnectedError: If `self.is_connected` is already false.
    """
    if not self.is_connected:
        raise RobotDeviceNotConnectedError(
            "KochRobot is not connected. You need to run `robot.connect()` before disconnecting."
        )

    # Each group is looked up only when its turn comes, in this fixed order.
    for group_attr in ("follower_arms", "leader_arms", "cameras"):
        devices = getattr(self, group_attr)
        for name in devices:
            devices[name].disconnect()

    self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
9
lerobot/common/robot_devices/robots/utils.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class Robot(Protocol):
    """Structural interface expected from a robot by the control scripts.

    Any class providing these methods satisfies the protocol; no inheritance is needed.
    `connect` and `disconnect` are part of the interface because the control modes call
    them on every robot they receive.
    """

    def init_teleop(self): ...

    def run_calibration(self): ...

    def teleop_step(self, record_data=False): ...

    def capture_observation(self): ...

    def send_action(self, action): ...

    def connect(self): ...

    def disconnect(self): ...
|
||||
19
lerobot/common/robot_devices/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
class RobotDeviceNotConnectedError(Exception):
    """Raised when an operation is attempted on a robot device that is not connected."""

    def __init__(
        self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
    ):
        # The text is available both as `str(error)` and as `error.message`.
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
class RobotDeviceAlreadyConnectedError(Exception):
    """Raised when a connection is attempted on a robot device that is already connected."""

    def __init__(
        self,
        message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
    ):
        # The text is available both as `str(error)` and as `error.message`.
        super().__init__(message)
        self.message = message
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
@@ -158,6 +158,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -172,3 +173,7 @@ def print_cuda_memory_usage():
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
|
||||
|
||||
def capture_timestamp_utc():
    """Return the current time as a timezone-aware `datetime` in UTC."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now
|
||||
|
||||
@@ -32,19 +32,54 @@ video_backend: pyav
|
||||
|
||||
training:
|
||||
offline_steps: ???
|
||||
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||
online_steps: ???
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
|
||||
# Number of workers for the offline training dataloader.
|
||||
num_workers: 4
|
||||
|
||||
batch_size: ???
|
||||
|
||||
eval_freq: ???
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
|
||||
# Online training. Note that the online training loop adopts most of the options above apart from the
|
||||
# dataloader options. Unless otherwise specified.
|
||||
# The online training loop looks something like:
|
||||
#
|
||||
# for i in range(online_steps):
|
||||
# do_online_rollout_and_update_online_buffer()
|
||||
# for j in range(online_steps_between_rollouts):
|
||||
# batch = next(dataloader_with_offline_and_online_data)
|
||||
# loss = policy(batch)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
#
|
||||
online_steps: ???
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
online_rollout_n_episodes: 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to be an even divisor of online_rollout_n_episodes.
|
||||
online_rollout_batch_size: 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
online_steps_between_rollouts: null
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
online_sampling_ratio: 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
online_env_seed: null
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
online_buffer_capacity: null
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
online_buffer_seed_size: 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_online_rollout_async: false
|
||||
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
|
||||
10
lerobot/configs/env/koch_real.yaml
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
2
lerobot/configs/env/xarm.yaml
vendored
@@ -9,7 +9,7 @@ env:
|
||||
state_dim: 4
|
||||
action_dim: 4
|
||||
fps: ${fps}
|
||||
episode_length: 25
|
||||
episode_length: 200
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
||||
@@ -75,7 +75,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
102
lerobot/configs/policy/act_koch_real.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in the header of `control_robot.py` for more information on how to collect data, train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/koch_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
# Named `temporal_ensemble_coeff` (not `temporal_ensemble_momentum`) to match the other ACT configs.
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -107,7 +107,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -103,7 +103,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -4,19 +4,30 @@ seed: 1
|
||||
dataset_repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
training:
|
||||
offline_steps: 25000
|
||||
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||
online_steps: 0 # 25000 not implemented yet
|
||||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
log_freq: 100
|
||||
offline_steps: 50000
|
||||
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 5000
|
||||
log_freq: 100
|
||||
|
||||
online_steps: 50000
|
||||
online_rollout_n_episodes: 1
|
||||
online_rollout_batch_size: 1
|
||||
# Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
|
||||
# the last sampled episode.
|
||||
online_steps_between_rollouts: 50
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
# FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
|
||||
# is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
|
||||
# 80000.
|
||||
online_buffer_capacity: 80000
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
@@ -31,6 +42,7 @@ policy:
|
||||
# Input / output structure.
|
||||
n_action_repeats: 2
|
||||
horizon: 5
|
||||
n_action_steps: 1
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
||||
105
lerobot/configs/policy/tdmpc_pusht_keypoints.yaml
Normal file
@@ -0,0 +1,105 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
# policy=tdmpc_pusht_keypoints \
|
||||
# eval.batch_size=50 \
|
||||
# eval.n_episodes=50 \
|
||||
# eval.use_async_envs=true \
|
||||
# device=cuda \
|
||||
# use_amp=true
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 10000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
pretrained_model_path: null
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.98
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
||||
39
lerobot/configs/robot/koch.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
|
||||
calibration_path: .cache/calibration/koch.pkl
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
862
lerobot/scripts/control_robot.py
Normal file
@@ -0,0 +1,862 @@
|
||||
"""
|
||||
Utilities to control a robot.
|
||||
|
||||
Useful to record a dataset, replay a recorded episode, run the policy on your robot
|
||||
and record an evaluation dataset, and to recalibrate your robot if needed.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Recalibrate your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate
|
||||
|
||||
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
|
||||
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--fps 30
|
||||
```
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--num-episodes 1 \
|
||||
--run-compute-stats 0
|
||||
```
|
||||
|
||||
- Visualize dataset:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
- Replay this test episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
|
||||
- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
|
||||
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
To avoid resuming by deleting the dataset, use `--force-override 1`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
hydra.run.dir=outputs/train/act_koch_real
|
||||
```
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10 \
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
from termcolor import colored
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def say(text, blocking=False):
    """Speak `text` aloud with the operating system's text-to-speech command.

    Uses `say` on macOS, `spd-say` on Linux and PowerShell's speech synthesizer on
    Windows. On any other platform a warning is logged and nothing is spoken.

    Args:
        text: Sentence to speak. It is quoted before being placed in the shell command,
            so quotes or shell metacharacters in it are spoken rather than executed.
        blocking: When False on macOS/Linux, the command is started in the background.
            On Windows the call always waits for the command to finish.
    """
    import shlex  # local import: only this function needs it

    system = platform.system()
    text = str(text)

    # Check if mac, linux, or windows.
    if system == "Darwin":
        cmd = f"say {shlex.quote(text)}"
    elif system == "Linux":
        cmd = f"spd-say {shlex.quote(text)}"
    elif system == "Windows":
        # Inside a PowerShell single-quoted string, a quote is escaped by doubling it.
        # NOTE(review): a double quote in `text` still ends the outer -Command argument.
        escaped = text.replace("'", "''")
        cmd = (
            'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{escaped}')\""
        )
    else:
        # Previously `cmd` was left unbound here and the function crashed.
        logging.warning(f"Text-to-speech is not supported on platform '{system}'.")
        return

    if not blocking and system in ["Darwin", "Linux"]:
        # TODO(rcadene): Make it work for Windows
        # Use the ampersand to run command in the background
        cmd += " &"

    os.system(cmd)
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
    """Save one frame as `<videos_dir>/<key>_episode_<episode>/frame_<frame>.png`.

    Episode and frame indices are zero-padded to 6 digits. The episode directory is
    created if it does not exist yet.
    """
    frame = Image.fromarray(img_tensor.numpy())
    episode_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
    episode_dir.mkdir(parents=True, exist_ok=True)
    frame_path = episode_dir / f"frame_{frame_index:06d}.png"
    frame.save(str(frame_path), quality=100)
|
||||
|
||||
|
||||
def busy_wait(seconds):
    """Spin until `seconds` have elapsed on `time.perf_counter`.

    Returns immediately when `seconds` is zero or negative.
    """
    # Significantly more accurate than `time.sleep`, and mandatory for our use case,
    # but it consumes CPU cycles.
    # TODO(rcadene): find an alternative: from Python 3.11, time.sleep is precise
    deadline = time.perf_counter() + seconds
    while True:
        if not time.perf_counter() < deadline:
            break
|
||||
|
||||
|
||||
def none_or_int(value):
    """Return None for the literal string "None", otherwise `int(value)`."""
    return None if value == "None" else int(value)
|
||||
|
||||
|
||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
    """Log one line of timing information for a control-loop step.

    The line lists, in order: the optional episode and frame indices, the total step
    time, then whichever per-device timings are present in `robot.logs` (leader reads,
    follower writes and reads, camera reads). Each timing is shown in milliseconds
    together with its frequency in Hz.

    Args:
        robot: Object exposing `leader_arms`, `follower_arms`, `cameras` and `logs`.
        dt_s: Duration of the whole step, in seconds.
        episode_index: Optional episode number to include in the line.
        frame_index: Optional frame number to include in the line.
        fps: Optional target frequency. When the achieved frequency `1 / dt_s` is more
            than 1 Hz below it, the line is colored yellow.
    """

    def format_dt(shortname, dt_val_s):
        # duration in milliseconds followed by the matching frequency
        return f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"

    parts = []
    if episode_index is not None:
        parts.append(f"ep:{episode_index}")
    if frame_index is not None:
        parts.append(f"frame:{frame_index}")

    # total step time displayed in milliseconds and its frequency
    parts.append(format_dt("dt", dt_s))

    for name in robot.leader_arms:
        key = f"read_leader_{name}_pos_dt_s"
        if key in robot.logs:
            parts.append(format_dt("dtRlead", robot.logs[key]))

    for name in robot.follower_arms:
        # for each follower: write timing first, then read timing
        for shortname, key in (
            ("dtWfoll", f"write_follower_{name}_goal_pos_dt_s"),
            ("dtRfoll", f"read_follower_{name}_pos_dt_s"),
        ):
            if key in robot.logs:
                parts.append(format_dt(shortname, robot.logs[key]))

    for name in robot.cameras:
        key = f"read_camera_{name}_dt_s"
        if key in robot.logs:
            parts.append(format_dt(f"dtR{name}", robot.logs[key]))

    info_str = " ".join(parts)
    # Highlight steps that ran noticeably slower than the requested frequency.
    if fps is not None and 1 / dt_s < fps - 1:
        info_str = colored(info_str, "yellow")
    logging.info(info_str)
|
||||
|
||||
|
||||
@cache
def is_headless():
    """Report whether `pynput` fails to import, which is treated as running headless.

    Returns False when the import succeeds. On any import error, a notice and the
    traceback are printed and True is returned. The result is cached, so the import is
    attempted (and the notice printed) at most once.
    """
    try:
        import pynput  # noqa
    except Exception:
        print(
            "Error trying to import pynput. Switching to headless mode. "
            "As a result, the video stream from the cameras won't be shown, "
            "and you won't be able to change the control flow with keyboards. "
            "For more info, see traceback below.\n"
        )
        traceback.print_exc()
        print()
        return True
    else:
        return False
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
|
||||
def calibrate(robot: Robot):
    """Force a fresh calibration by deleting the calibration file and reconnecting."""
    calibration_file = robot.calibration_path
    if calibration_file.exists():
        print(f"Removing '{calibration_file}'")
        calibration_file.unlink()

    # Start from a disconnected robot so that `connect` below goes through its full setup.
    if robot.is_connected:
        robot.disconnect()

    # Calling `connect` automatically runs calibration
    # when the calibration file is missing
    robot.connect()
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
    """Run teleoperation steps in a loop, optionally rate-limited and time-limited.

    Args:
        robot: Robot to drive; it is connected first if it is not already.
        fps: When given, each iteration is padded with a busy wait so that it lasts
            `1 / fps` seconds. When None, the loop runs as fast as possible.
        teleop_time_s: When given, the loop stops once more than this many seconds have
            elapsed since it started. When None, the loop never returns on its own.
    """
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    teleop_started_t = time.perf_counter()
    keep_going = True
    while keep_going:
        loop_started_t = time.perf_counter()
        robot.teleop_step()

        if fps is not None:
            # Pad the step up to the target period.
            step_dt_s = time.perf_counter() - loop_started_t
            busy_wait(1 / fps - step_dt_s)

        # Measured again so the logged duration includes the padding.
        log_control_info(robot, time.perf_counter() - loop_started_t, fps=fps)

        if teleop_time_s is not None:
            keep_going = not (time.perf_counter() - teleop_started_t > teleop_time_s)
|
||||
|
||||
|
||||
def record(
    robot: Robot,
    policy: torch.nn.Module | None = None,
    hydra_cfg: DictConfig | None = None,
    fps: int | None = None,
    root="data",
    repo_id="lerobot/debug",
    warmup_time_s=2,
    episode_time_s=10,
    reset_time_s=5,
    num_episodes=50,
    video=True,
    run_compute_stats=True,
    push_to_hub=True,
    tags=None,
    num_image_writers=8,
    force_override=False,
):
    """Record episodes from `robot`, save them locally and build a `LeRobotDataset`.

    Actions come from `robot.teleop_step(record_data=True)` when `policy` is None,
    otherwise from `policy.select_action` applied to each captured observation.
    Image observations are written to disk by a thread pool while recording and
    encoded to mp4 once recording is over; the other observations and the actions
    are stacked per episode and saved as `episode_{i}.pth`. Recording resumes after
    the last episode listed in `data_recording_info.json` unless `force_override`
    wipes the local directory first.

    When not headless, the keyboard controls the flow: right arrow ends the current
    episode/reset early, left arrow additionally re-records the episode, escape
    additionally stops the recording.

    Args:
        robot: Robot to record from; connected here if needed, disconnected once
            recording is over.
        policy: Optional policy driving the robot. Mandatory when the dataset name
            starts with "eval_".
        hydra_cfg: Config read for `device`, `seed`, `use_amp` and `env.fps` when
            `policy` is given.
        fps: Control-loop frequency, also used to timestamp frames and encode videos.
            Required when `policy` is None; replaced by `hydra_cfg.env.fps` otherwise.
        root: Local root directory; data goes to `Path(root) / repo_id`.
        repo_id: Dataset identifier of the form "<namespace>/<dataset_name>".
        warmup_time_s: Seconds spent running the loop without recording.
        episode_time_s: Maximum duration of one episode, in seconds.
        reset_time_s: Maximum time waited between two episodes, in seconds.
        num_episodes: Recording stops once the episode index reaches this value.
        video: Must be truthy; frame-by-frame storage is not implemented.
        run_compute_stats: Whether to run `compute_stats` on the resulting dataset.
        push_to_hub: Whether to upload dataset, metadata, card and videos to the hub.
        tags: Tags forwarded to `push_dataset_card_to_hub`.
        num_image_writers: Number of worker threads saving frames to disk.
        force_override: When truthy, delete an existing local directory first.

    Returns:
        The `LeRobotDataset` built from every episode found on disk.

    Raises:
        ValueError: If `repo_id` does not split into exactly two parts on "/", if
            the dataset name starts with "eval_" while `policy` is None, or if
            neither `fps` nor `policy` is provided.
        NotImplementedError: If `video` is falsy.
    """
    # TODO(rcadene): Add option to record logs
    # TODO(rcadene): Clean this function via decomposition in higher level functions

    _, dataset_name = repo_id.split("/")
    if dataset_name.startswith("eval_") and policy is None:
        raise ValueError(
            f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
        )

    # `fps` paces every loop below (`1 / fps`) and timestamps every frame (`i / fps`).
    # Without a policy to supply it, fail here rather than with a TypeError in the
    # middle of the warmup, after the robot is connected and directories are created.
    if policy is None and fps is None:
        raise ValueError(
            "`fps` must be provided when recording without a policy: it is required to pace "
            "the control loop and to timestamp the recorded frames."
        )

    if not video:
        raise NotImplementedError()

    if not robot.is_connected:
        robot.connect()

    local_dir = Path(root) / repo_id
    if local_dir.exists() and force_override:
        shutil.rmtree(local_dir)

    episodes_dir = local_dir / "episodes"
    episodes_dir.mkdir(parents=True, exist_ok=True)

    videos_dir = local_dir / "videos"
    videos_dir.mkdir(parents=True, exist_ok=True)

    # Logic to resume data recording
    rec_info_path = episodes_dir / "data_recording_info.json"
    if rec_info_path.exists():
        with open(rec_info_path) as f:
            rec_info = json.load(f)
        episode_index = rec_info["last_episode_index"] + 1
    else:
        episode_index = 0

    if is_headless():
        logging.info(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )

    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    exit_early = False
    rerecord_episode = False
    stop_recording = False

    # Only import pynput if not in a headless environment
    if not is_headless():
        from pynput import keyboard

        def on_press(key):
            # Runs in the listener thread; only flips flags polled by the loops below.
            nonlocal exit_early, rerecord_episode, stop_recording
            try:
                if key == keyboard.Key.right:
                    print("Right arrow key pressed. Exiting loop...")
                    exit_early = True
                elif key == keyboard.Key.left:
                    print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                    rerecord_episode = True
                    exit_early = True
                elif key == keyboard.Key.esc:
                    print("Escape key pressed. Stopping data recording...")
                    stop_recording = True
                    exit_early = True
            except Exception as e:
                print(f"Error handling key press: {e}")

        listener = keyboard.Listener(on_press=on_press)
        listener.start()

    # Load policy if any
    if policy is not None:
        # Check device is available
        device = get_safe_torch_device(hydra_cfg.device, log=True)

        policy.eval()
        policy.to(device)

        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        set_global_seed(hydra_cfg.seed)

        # override fps using policy fps
        fps = hydra_cfg.env.fps

    # Execute a few seconds without recording data, to give times
    # to the robot devices to connect and start synchronizing.
    timestamp = 0
    start_warmup_t = time.perf_counter()
    is_warmup_print = False
    while timestamp < warmup_time_s:
        if not is_warmup_print:
            logging.info("Warming up (no data recording)")
            say("Warming up")
            is_warmup_print = True

        start_loop_t = time.perf_counter()

        if policy is None:
            observation, action = robot.teleop_step(record_data=True)
        else:
            observation = robot.capture_observation()

        if not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_warmup_t

    # Save images using threads to reach high fps (30 and more)
    # Using `with` to exit smoothly if an exception is raised.
    # The pool is limited to `num_image_writers` threads to avoid blocking the main thread.
    futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
        # Start recording all episodes
        while episode_index < num_episodes:
            logging.info(f"Recording episode {episode_index}")
            say(f"Recording episode {episode_index}")
            ep_dict = {}
            frame_index = 0
            timestamp = 0
            start_episode_t = time.perf_counter()
            while timestamp < episode_time_s:
                start_loop_t = time.perf_counter()

                if policy is None:
                    observation, action = robot.teleop_step(record_data=True)
                else:
                    observation = robot.capture_observation()

                image_keys = [key for key in observation if "image" in key]
                not_image_keys = [key for key in observation if "image" not in key]

                for key in image_keys:
                    futures += [
                        executor.submit(
                            save_image, observation[key], key, frame_index, episode_index, videos_dir
                        )
                    ]

                if not is_headless():
                    image_keys = [key for key in observation if "image" in key]
                    for key in image_keys:
                        cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                    cv2.waitKey(1)

                for key in not_image_keys:
                    if key not in ep_dict:
                        ep_dict[key] = []
                    ep_dict[key].append(observation[key])

                if policy is not None:
                    with (
                        torch.inference_mode(),
                        torch.autocast(device_type=device.type)
                        if device.type == "cuda" and hydra_cfg.use_amp
                        else nullcontext(),
                    ):
                        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
                        for name in observation:
                            if "image" in name:
                                observation[name] = observation[name].type(torch.float32) / 255
                                observation[name] = observation[name].permute(2, 0, 1).contiguous()
                            observation[name] = observation[name].unsqueeze(0)
                            observation[name] = observation[name].to(device)

                        # Compute the next action with the policy
                        # based on the current observation
                        action = policy.select_action(observation)

                        # Remove batch dimension
                        action = action.squeeze(0)

                        # Move to cpu, if not already the case
                        action = action.to("cpu")

                    # Order the robot to move
                    robot.send_action(action)
                    action = {"action": action}

                for key in action:
                    if key not in ep_dict:
                        ep_dict[key] = []
                    ep_dict[key].append(action[key])

                frame_index += 1

                dt_s = time.perf_counter() - start_loop_t
                busy_wait(1 / fps - dt_s)

                dt_s = time.perf_counter() - start_loop_t
                log_control_info(robot, dt_s, fps=fps)

                timestamp = time.perf_counter() - start_episode_t
                if exit_early:
                    exit_early = False
                    break

            if not stop_recording:
                # Start resetting env while the executor are finishing
                logging.info("Reset the environment")
                say("Reset the environment")

            timestamp = 0
            start_vencod_t = time.perf_counter()

            # During env reset we save the data and encode the videos
            num_frames = frame_index

            for key in image_keys:
                tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                fname = f"{key}_episode_{episode_index:06d}.mp4"
                video_path = local_dir / "videos" / fname
                if video_path.exists():
                    # A stale video (e.g. from a re-recorded episode) would otherwise
                    # make the encoding step below skip this episode.
                    video_path.unlink()
                # Store the reference to the video frame, even tho the videos are not yet encoded
                ep_dict[key] = []
                for i in range(num_frames):
                    ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})

            for key in not_image_keys:
                ep_dict[key] = torch.stack(ep_dict[key])

            for key in action:
                ep_dict[key] = torch.stack(ep_dict[key])

            ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps

            # Only the final frame of the episode is flagged as done.
            done = torch.zeros(num_frames, dtype=torch.bool)
            done[-1] = True
            ep_dict["next.done"] = done

            ep_path = episodes_dir / f"episode_{episode_index}.pth"
            print("Saving episode dictionary...")
            torch.save(ep_dict, ep_path)

            rec_info = {
                "last_episode_index": episode_index,
            }
            with open(rec_info_path, "w") as f:
                json.dump(rec_info, f)

            is_last_episode = stop_recording or (episode_index == (num_episodes - 1))

            # Wait if necessary
            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
                while timestamp < reset_time_s and not is_last_episode:
                    time.sleep(1)
                    timestamp = time.perf_counter() - start_vencod_t
                    pbar.update(1)
                    if exit_early:
                        exit_early = False
                        break

            # Skip updating episode index which forces re-recording episode
            if rerecord_episode:
                rerecord_episode = False
                continue

            episode_index += 1

            if is_last_episode:
                logging.info("Done recording")
                say("Done recording", blocking=True)
                if not is_headless():
                    listener.stop()

                logging.info("Waiting for threads writing the images on disk to terminate...")
                for _ in tqdm.tqdm(
                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
                ):
                    pass
                break

    robot.disconnect()
    if not is_headless():
        cv2.destroyAllWindows()

    # From here on, `num_episodes` is the number of episodes available on disk.
    num_episodes = episode_index

    logging.info("Encoding videos")
    say("Encoding videos")
    # Use ffmpeg to convert frames stored as png into mp4 videos
    # NOTE(review): `image_keys` is only bound by the warmup preview (non-headless) or by the
    # recording loop; if neither ran in this call (e.g. headless resume of a finished
    # recording), the loop below raises NameError — confirm whether that case must be supported.
    for episode_index in tqdm.tqdm(range(num_episodes)):
        for key in image_keys:
            tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            video_path = local_dir / "videos" / fname
            if video_path.exists():
                # Skip if video is already encoded. Could be the case when resuming data recording.
                continue
            # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
            # since video encoding with ffmpeg is already using multithreading.
            encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
            shutil.rmtree(tmp_imgs_dir)

    logging.info("Concatenating episodes")
    ep_dicts = []
    for episode_index in tqdm.tqdm(range(num_episodes)):
        ep_path = episodes_dir / f"episode_{episode_index}.pth"
        ep_dict = torch.load(ep_path)
        ep_dicts.append(ep_dict)
    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)

    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "codebase_version": CODEBASE_VERSION,
        "fps": fps,
        "video": video,
    }
    if video:
        info["encoding"] = get_default_encoding()

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    if run_compute_stats:
        logging.info("Computing dataset statistics")
        say("Computing dataset statistics")
        stats = compute_stats(lerobot_dataset)
        lerobot_dataset.stats = stats
    else:
        stats = {}
        logging.info("Skipping computation of the dataset statistics")

    hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
    hf_dataset.save_to_disk(str(local_dir / "train"))

    meta_data_dir = local_dir / "meta_data"
    save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    logging.info("Exiting")
    say("Exiting")
    return lerobot_dataset
|
||||
|
||||
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
    """Send the recorded actions of one episode of a local dataset to `robot`.

    Args:
        robot: Robot receiving the actions; connected here if it is not already.
        episode: Index of the episode to replay, looked up in
            `dataset.episode_data_index`.
        fps: When not None, each iteration hands `busy_wait` whatever is left of
            its `1 / fps` second budget. When None, no wait is inserted, so
            actions are sent back to back.
        root: Local root directory holding the dataset.
        repo_id: Dataset identifier; the dataset must exist at `Path(root) / repo_id`.

    Raises:
        ValueError: If `Path(root) / repo_id` does not exist.
    """
    # TODO(rcadene): Add option to record logs
    local_dir = Path(root) / repo_id
    if not local_dir.exists():
        raise ValueError(local_dir)

    dataset = LeRobotDataset(repo_id, root=root)
    items = dataset.hf_dataset.select_columns("action")
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()

    if not robot.is_connected:
        robot.connect()

    logging.info("Replaying episode")
    say("Replaying episode", blocking=True)
    for idx in range(from_idx, to_idx):
        start_episode_t = time.perf_counter()

        action = items[idx]["action"]
        robot.send_action(action)

        # `fps=None` is the default and means "no rate limiting" (as in `teleoperate`);
        # dividing by it unconditionally raised a TypeError on the first frame.
        if fps is not None:
            dt_s = time.perf_counter() - start_episode_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_episode_t
        log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
|
||||
# Command-line entry point: one sub-command per control mode
# (calibrate / teleoperate / record / replay), all sharing the robot options.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="mode", required=True)

    # Set common options for all the subparsers
    base_parser = argparse.ArgumentParser(add_help=False)
    base_parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    base_parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
    )

    # `calibrate` takes no option beyond the shared robot ones.
    parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])

    parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
    parser_teleop.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )

    # Every option below maps by name onto a keyword argument of `record`
    # (argparse turns `--some-flag` into `some_flag`).
    parser_record = subparsers.add_parser("record", parents=[base_parser])
    # NOTE(review): `record` divides by `fps`, so it needs a concrete value unless a
    # pretrained policy supplies one through its config.
    parser_record.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_record.add_argument(
        "--root",
        type=Path,
        default="data",
        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
    )
    parser_record.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_record.add_argument(
        "--warmup-time-s",
        type=int,
        default=10,
        help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
    )
    parser_record.add_argument(
        "--episode-time-s",
        type=int,
        default=60,
        help="Number of seconds for data recording for each episode.",
    )
    parser_record.add_argument(
        "--reset-time-s",
        type=int,
        default=60,
        help="Number of seconds for resetting the environment after each episode.",
    )
    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
    # The on/off switches below are integers (0 or 1) rather than store_true flags.
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,
        default=1,
        help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
    )
    parser_record.add_argument(
        "--push-to-hub",
        type=int,
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    parser_record.add_argument(
        "--num-image-writers",
        type=int,
        default=8,
        help="Number of threads writing the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.",
    )
    parser_record.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
    )
    parser_record.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        type=str,
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`."
        ),
    )
    parser_record.add_argument(
        "--policy-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for.nested=overrides)",
    )

    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
    )
    parser_replay.add_argument(
        "--root",
        type=Path,
        default="data",
        help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
    )
    parser_replay.add_argument(
        "--repo-id",
        type=str,
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")

    args = parser.parse_args()

    init_logging()

    # Read the options consumed here first, then strip them from `kwargs` so that
    # what remains can be forwarded as-is to the selected control function.
    # `vars(args)` is the namespace's own dict, so the `del`s below also remove the
    # attributes from `args`.
    control_mode = args.mode
    robot_path = args.robot_path
    robot_overrides = args.robot_overrides
    kwargs = vars(args)
    del kwargs["mode"]
    del kwargs["robot_path"]
    del kwargs["robot_overrides"]

    robot_cfg = init_hydra_config(robot_path, robot_overrides)
    robot = make_robot(robot_cfg)

    if control_mode == "calibrate":
        calibrate(robot, **kwargs)

    elif control_mode == "teleoperate":
        teleoperate(robot, **kwargs)

    elif control_mode == "record":
        # The policy options are not parameters of `record`: read them before
        # deleting them from the forwarded kwargs.
        pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
        policy_overrides = args.policy_overrides
        del kwargs["pretrained_policy_name_or_path"]
        del kwargs["policy_overrides"]

        policy_cfg = None
        if pretrained_policy_name_or_path is not None:
            # Resolve the checkpoint location, then build the policy from the
            # `config.yaml` stored next to the weights.
            pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
            policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
            policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
            record(robot, policy, policy_cfg, **kwargs)
        else:
            record(robot, **kwargs)

    elif control_mode == "replay":
        replay(robot, **kwargs)

    if robot.is_connected:
        # Disconnect manually to avoid a "Core dump" during process
        # termination due to camera threads not properly exiting.
        robot.disconnect()
|
||||
@@ -56,16 +56,13 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from PIL import Image as PILImage
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
@@ -318,41 +315,17 @@ def eval_policy(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(
|
||||
0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
|
||||
),
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are not correctly compiling the data.
|
||||
assert (
|
||||
episode_data["hf_dataset"]["episode_index"][-1] + 1
|
||||
== this_episode_data["hf_dataset"]["episode_index"][0]
|
||||
)
|
||||
assert (
|
||||
episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
|
||||
)
|
||||
assert torch.equal(
|
||||
episode_data["episode_data_index"]["to"][-1],
|
||||
this_episode_data["episode_data_index"]["from"][0],
|
||||
)
|
||||
# Some sanity checks to make sure we are correctly compiling the data.
|
||||
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
|
||||
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
|
||||
# Concatenate the episode data.
|
||||
episode_data = {
|
||||
"hf_dataset": concatenate_datasets(
|
||||
[episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
|
||||
),
|
||||
"episode_data_index": {
|
||||
k: torch.cat(
|
||||
[
|
||||
episode_data["episode_data_index"][k],
|
||||
this_episode_data["episode_data_index"][k],
|
||||
]
|
||||
)
|
||||
for k in ["from", "to"]
|
||||
},
|
||||
}
|
||||
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
|
||||
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0 and len(ep_frames) > 0:
|
||||
@@ -434,89 +407,39 @@ def _compile_episode_data(
|
||||
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
|
||||
"""
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
total_frames = 0
|
||||
data_index_from = start_data_index
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame
|
||||
# + 2 to include the first done frame and the last observation frame.
|
||||
num_frames = done_indices[ep_ix].item() + 2
|
||||
total_frames += num_frames
|
||||
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, :num_frames],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, :num_frames],
|
||||
"next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
||||
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
||||
}
|
||||
|
||||
# For the last observation frame, all other keys will just be copy padded.
|
||||
for k in ep_dict:
|
||||
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
|
||||
|
||||
for key in rollout_data["observation"]:
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(data_index_from)
|
||||
episode_data_index["to"].append(data_index_from + num_frames)
|
||||
|
||||
data_index_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
for key in ep_dicts[0]:
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in rollout_data["observation"]:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
return data_dict
|
||||
|
||||
|
||||
def main(
|
||||
@@ -578,6 +501,29 @@ def main(
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
    """Resolve a pretrained policy reference to a local directory.

    The reference is first tried as a Hugging Face Hub repo ID through
    `snapshot_download`. If that raises `HFValidationError` or
    `RepositoryNotFoundError`, a warning is logged and the reference is used
    as a local path instead.

    Args:
        pretrained_policy_name_or_path: Hub repo ID or local directory path.
        revision: Revision forwarded to `snapshot_download`.

    Returns:
        `Path` of the directory holding the pretrained policy.

    Raises:
        ValueError: If the resolved path is not an existing directory.
    """
    try:
        resolved_dir = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
    except (HFValidationError, RepositoryNotFoundError) as hub_error:
        reason = (
            "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
            if isinstance(hub_error, HFValidationError)
            else "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
        )
        logging.warning(f"{reason} Treating it as a local directory.")
        resolved_dir = Path(pretrained_policy_name_or_path)

    if not (resolved_dir.is_dir() and resolved_dir.exists()):
        raise ValueError(
            "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
            "repo ID, nor is it an existing local directory."
        )
    return resolved_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
@@ -619,27 +565,9 @@ if __name__ == "__main__":
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
)
|
||||
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
|
||||
@@ -40,60 +40,6 @@ python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-format umi_zarr \
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
```
|
||||
|
||||
**WARNING: Updating an existing dataset**
|
||||
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` in `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require running `push_dataset_to_hub.py`. You can just "branch out" from the `main` branch on the HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
from huggingface_hub import create_branch, hf_hub_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" # makes it easier to see the print-out below
|
||||
|
||||
NEW_CODEBASE_VERSION = "v1.5" # REPLACE THIS WITH YOUR DESIRED VERSION
|
||||
|
||||
for repo_id in available_datasets:
|
||||
# First check if the newer version already exists.
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
|
||||
)
|
||||
print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
|
||||
print("Exiting early")
|
||||
break
|
||||
except RepositoryNotFoundError:
|
||||
# Now create a branch.
|
||||
create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
|
||||
print(f"{repo_id} successfully updated")
|
||||
|
||||
```
|
||||
|
||||
On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
|
||||
above, nor to be compatible with previous codebase versions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -104,12 +50,13 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, create_branch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
@@ -167,6 +114,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
|
||||
)
|
||||
|
||||
|
||||
def push_dataset_card_to_hub(
    repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
):
    """Build a LeRobotDataset card from `tags` and `text` and push it to the dataset repo `repo_id`.

    The card is created by `create_lerobot_dataset_card` and uploaded to the given `revision`
    of the hub repository (repo type "dataset").
    """
    create_lerobot_dataset_card(tags=tags, text=text).push_to_hub(
        repo_id=repo_id,
        repo_type="dataset",
        revision=revision,
    )
|
||||
|
||||
|
||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||
On the Hugging Face repository, they will be uploaded in a "videos" directory at the root.
|
||||
@@ -194,22 +149,20 @@ def push_dataset_to_hub(
|
||||
num_workers: int = 8,
|
||||
episodes: list[int] | None = None,
|
||||
force_override: bool = False,
|
||||
resume: bool = False,
|
||||
cache_dir: Path = Path("/tmp"),
|
||||
tests_data_dir: Path | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# Check repo_id is well formated
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
|
||||
)
|
||||
check_repo_id(repo_id)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
# Robustify when `raw_dir` is str instead of Path
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
raise NotADirectoryError(
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
|
||||
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
|
||||
f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
|
||||
)
|
||||
|
||||
if local_dir:
|
||||
@@ -227,7 +180,7 @@ def push_dataset_to_hub(
|
||||
if local_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
else:
|
||||
elif not resume:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
@@ -245,7 +198,7 @@ def push_dataset_to_hub(
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps, video, episodes
|
||||
raw_dir, videos_dir, fps, video, episodes, encoding
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
@@ -268,6 +221,7 @@ def push_dataset_to_hub(
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main")
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
@@ -368,6 +322,12 @@ def main():
|
||||
default=0,
|
||||
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="When set to 1, resumes a previous run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
|
||||
@@ -15,20 +15,25 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from threading import Lock
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
@@ -107,6 +112,7 @@ def update_policy(
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
):
|
||||
"""Returns a dictionary of items for logging."""
|
||||
start_time = time.perf_counter()
|
||||
@@ -129,7 +135,8 @@ def update_policy(
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
grad_scaler.step(optimizer)
|
||||
with lock if lock is not None else nullcontext():
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
@@ -149,11 +156,12 @@ def update_policy(
|
||||
"update_s": time.perf_counter() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
@@ -187,12 +195,12 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
||||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_offline"] = is_offline
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="train")
|
||||
|
||||
|
||||
def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
||||
def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
eval_s = info["eval_s"]
|
||||
avg_sum_reward = info["avg_sum_reward"]
|
||||
pc_success = info["pc_success"]
|
||||
@@ -221,7 +229,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
||||
info["num_samples"] = num_samples
|
||||
info["num_episodes"] = num_episodes
|
||||
info["num_epochs"] = num_epochs
|
||||
info["is_offline"] = is_offline
|
||||
info["is_online"] = is_online
|
||||
|
||||
logger.log_dict(info, step, mode="eval")
|
||||
|
||||
@@ -234,6 +242,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
init_logging()
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
|
||||
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
|
||||
# to check for any differences between the provided config and the checkpoint's config.
|
||||
if cfg.resume:
|
||||
@@ -272,15 +283,13 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
cfg.resume = True
|
||||
elif Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
|
||||
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
|
||||
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
||||
)
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
if cfg.training.online_steps > 0:
|
||||
raise NotImplementedError("Online training is not implemented yet.")
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
@@ -335,7 +344,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step):
|
||||
def evaluate_and_checkpoint_if_needed(step, is_online):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
@@ -351,7 +360,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
@@ -395,8 +404,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.training.offline_steps):
|
||||
if step == 0:
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
@@ -419,13 +429,207 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1)
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
|
||||
|
||||
step += 1
|
||||
offline_step += 1 # noqa: SIM113
|
||||
|
||||
if cfg.training.online_steps == 0:
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
return
|
||||
|
||||
# Online training.
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
resolve_delta_timestamps(cfg)
|
||||
online_buffer_path = logger.log_dir / "online_buffer"
|
||||
if cfg.resume and not online_buffer_path.exists():
|
||||
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
|
||||
# buffer.
|
||||
logging.warning(
|
||||
"When online training is resumed, we load the latest online buffer from the prior run, "
|
||||
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
|
||||
"was made. This is because the online buffer is updated on disk during training, independently "
|
||||
"of our explicit checkpointing mechanisms."
|
||||
)
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
buffer_capacity=cfg.training.online_buffer_capacity,
|
||||
fps=online_env.unwrapped.metadata["render_fps"],
|
||||
delta_timestamps=cfg.training.delta_timestamps,
|
||||
)
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
sampler_weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
sampler_weights,
|
||||
num_samples=len(concat_dataset),
|
||||
replacement=True,
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
|
||||
# these are still used but effectively do nothing.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
rollout_start_seed = cfg.training.online_env_seed
|
||||
|
||||
while True:
|
||||
if online_step == cfg.training.online_steps:
|
||||
break
|
||||
|
||||
if online_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
|
||||
def sample_trajectory_and_update_buffer():
|
||||
nonlocal rollout_start_seed
|
||||
with lock:
|
||||
online_rollout_policy.load_state_dict(policy.state_dict())
|
||||
online_rollout_policy.eval()
|
||||
start_rollout_time = time.perf_counter()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.training.online_rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(
|
||||
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
|
||||
),
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
|
||||
with lock:
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
# Update the concatenated dataset length used during sampling.
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler.num_samples = len(concat_dataset)
|
||||
|
||||
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
|
||||
|
||||
return online_rollout_s, update_online_buffer_s
|
||||
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if (
|
||||
not cfg.training.do_online_rollout_async
|
||||
or len(online_dataset) <= cfg.training.online_buffer_seed_size
|
||||
):
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
|
||||
if len(online_dataset) <= cfg.training.online_buffer_seed_size:
|
||||
logging.info(
|
||||
f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
|
||||
)
|
||||
continue
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
with lock:
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
lock=lock,
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
train_info["online_rollout_s"] = online_rollout_s
|
||||
train_info["update_online_buffer_s"] = update_online_buffer_s
|
||||
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
|
||||
with lock:
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
|
||||
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
|
||||
# to do the next batch of rollouts.
|
||||
if future.running():
|
||||
start = time.perf_counter()
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
await_update_online_buffer_s = time.perf_counter() - start
|
||||
|
||||
if online_step >= cfg.training.online_steps:
|
||||
break
|
||||
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
|
||||
@@ -108,8 +108,8 @@ def visualize_dataset(
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
root: Path | None = None,
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
@@ -209,6 +209,18 @@ def main():
|
||||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
@@ -254,17 +266,6 @@ def main():
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Root directory for a dataset stored on a local machine.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset(**vars(args))
|
||||
|
||||
300
lerobot/scripts/visualize_dataset_html.py
Normal file
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesn't always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Example of usage:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```bash
|
||||
local$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```bash
|
||||
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Select episodes to visualize:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in order, the frame indices belonging to a single episode."""

    def __init__(self, dataset, episode_index):
        # `episode_data_index` maps "from"/"to" to per-episode frame bounds; "to" is used as the
        # exclusive end of the range.
        bounds = dataset.episode_data_index
        first_frame = bounds["from"][episode_index].item()
        end_frame = bounds["to"][episode_index].item()
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self):
        yield_from = iter(self.frame_ids)
        return yield_from

    def __len__(self):
        n_frames = len(self.frame_ids)
        return n_frames
|
||||
|
||||
|
||||
def run_server(
    dataset: LeRobotDataset,
    episodes: list[int],
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    """Build a Flask app with one page per episode of `dataset` and run it.

    Args:
        dataset: Dataset whose `repo_id`, `num_samples`, `num_episodes` and `fps` are passed to
            the page template, and whose video paths are looked up per episode.
        episodes: Episode indices handed to the template; the home page redirects to the first one.
        host: Host passed to `app.run`.
        port: Port passed to `app.run`.
        static_folder: Folder used as Flask's static folder; episode CSV files and video paths
            are resolved against it through `url_for("static", ...)`.
        template_folder: Folder containing `visualize_dataset_template.html`.
    """
    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    def index():
        # home page redirects to the first episode page
        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
        first_episode_id = episodes[0]
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id):
        # `dataset_namespace` and `dataset_name` only shape the URL; the page content is always
        # taken from the single `dataset` this server was started with.
        dataset_info = {
            "repo_id": dataset.repo_id,
            "num_samples": dataset.num_samples,
            "num_episodes": dataset.num_episodes,
            "fps": dataset.fps,
        }
        video_paths = get_episode_video_paths(dataset, episode_id)
        videos_info = [
            {"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
            for video_path in video_paths
        ]
        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            ep_csv_url=ep_csv_url,
            has_policy=False,
        )

    app.run(host=host, port=port)
|
||||
|
||||
|
||||
def get_ep_csv_fname(episode_id: int):
    """Return the CSV file name used for the timeseries data of episode `episode_id`."""
    return "episode_{}.csv".format(episode_id)
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
    """Write a csv file containing timeseries data of an episode (e.g. state and action).
    This file will be loaded by Dygraph javascript to plot data in real time.

    Args:
        output_dir: Directory (a `Path`) receiving the file; created if missing.
        file_name: Name of the csv file inside `output_dir`.
        episode_index: Index of the episode whose frames are exported.
        dataset: Object exposing `episode_data_index` ("from"/"to" frame bounds per episode)
            and `hf_dataset` (with `features`, `select_columns` and row indexing).

    The header is `timestamp`, then `state_i` columns if "observation.state" is a feature,
    then `action_i` columns if "action" is a feature. One row is written per frame.
    """
    # Convert the bounds to plain ints so `range` gets the same kind of value whether the
    # index holds tensors or Python integers.
    from_idx = int(dataset.episode_data_index["from"][episode_index])
    to_idx = int(dataset.episode_data_index["to"][episode_index])

    features = dataset.hf_dataset.features
    has_state = "observation.state" in features
    has_action = "action" in features

    columns = ["timestamp"]
    if has_state:
        columns.append("observation.state")
    if has_action:
        columns.append("action")
    data = dataset.hf_dataset.select_columns(columns)

    # init header of csv with state and action names
    header = ["timestamp"]
    if has_state or has_action:
        # Read the vector widths from a single row instead of `hf_dataset["col"][0]`, which
        # fetches the whole column just to look at its first element.
        first_row = data[0]
        if has_state:
            header += [f"state_{i}" for i in range(len(first_row["observation.state"]))]
        if has_action:
            header += [f"action_{i}" for i in range(len(first_row["action"]))]

    lines = [",".join(header) + "\n"]
    for frame_idx in range(from_idx, to_idx):
        # Index the row once; every `data[i]` access re-reads the row.
        frame = data[frame_idx]
        row = [frame["timestamp"].item()]
        if has_state:
            row += frame["observation.state"].tolist()
        if has_action:
            row += frame["action"].tolist()
        lines.append(",".join(str(col) for col in row) + "\n")

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w") as f:
        f.writelines(lines)
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return, for each key in `dataset.video_frame_keys`, the "path" stored at the first frame of episode `ep_index`."""
    # The episode's first frame is used as a shortcut to reach the video path of the episode.
    start_idx = dataset.episode_data_index["from"][ep_index].item()

    paths = []
    for video_key in dataset.video_frame_keys:
        first_frame = dataset.hf_dataset.select_columns(video_key)[start_idx]
        paths.append(first_frame[video_key]["path"])
    return paths
|
||||
|
||||
|
||||
def visualize_dataset_html(
    repo_id: str,
    root: Path | None = None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Export per-episode CSV files for a video dataset and optionally serve them over http.

    Args:
        repo_id: Identifier passed to `LeRobotDataset`; also used to build the default output directory.
        root: Optional root directory passed to `LeRobotDataset`.
        episodes: Episode indices to export. `None` selects every episode of the dataset.
        output_dir: Where files are written. Defaults to `outputs/visualize_dataset_html/{repo_id}`.
        serve: When truthy, start the web server after the CSV files are written.
        host: Host passed to the web server.
        port: Port passed to the web server.
        force_override: When truthy, delete `output_dir` first if it already exists.

    Raises:
        NotImplementedError: If the dataset is not a video dataset (`dataset.video` is falsy).
    """
    init_logging()

    dataset = LeRobotDataset(repo_id, root=root)

    if not dataset.video:
        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")

    if output_dir is None:
        output_dir = f"outputs/visualize_dataset_html/{repo_id}"

    output_dir = Path(output_dir)
    if output_dir.exists():
        if force_override:
            shutil.rmtree(output_dir)
        else:
            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a symlink from the dataset video folder containing mp4 files to the output directory
    # so that the http server can get access to the mp4 files.
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)
    ln_videos_dir = static_dir / "videos"
    # `exists()` follows symlinks, so a link whose target has disappeared reports False and
    # `symlink_to` would then fail with FileExistsError. Drop such a dangling link first.
    if ln_videos_dir.is_symlink() and not ln_videos_dir.exists():
        ln_videos_dir.unlink()
    if not ln_videos_dir.exists():
        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if episodes is None:
        episodes = list(range(dataset.num_episodes))

    logging.info("Writing CSV files")
    for episode_index in tqdm.tqdm(episodes):
        # write states and actions in a csv (it can be slow for big datasets)
        ep_csv_fname = get_ep_csv_fname(episode_index)
        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)

    if serve:
        run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
    """Parse the command-line arguments and launch the HTML dataset visualizer.

    Every parsed option is forwarded as a keyword argument to `visualize_dataset_html`,
    so each argparse destination name (e.g. `--force-override` -> `force_override`) has to
    match one of that function's parameter names.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        # The default documented here must stay in sync with the fallback path built in
        # `visualize_dataset_html` when `output_dir` is None.
        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset_html/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        # Kept as an int flag (0/1) so that existing `--serve 0` invocations keep working.
        type=int,
        default=1,
        help="Launch web server. Set to 0 to only write the files without starting the server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        # Kept as an int flag (0/1) for backward compatibility with existing invocations.
        type=int,
        default=0,
        help="Delete the output directory if it exists already. Set to 1 to enable.",
    )

    args = parser.parse_args()
    visualize_dataset_html(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -25,7 +25,7 @@ Increase hue jitter
|
||||
```
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.hue.min_max=[-0.25,0.25]
|
||||
training.image_transforms.hue.min_max="[-0.25,0.25]"
|
||||
```
|
||||
|
||||
Increase brightness & brightness weight
|
||||
@@ -33,7 +33,7 @@ Increase brightness & brightness weight
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.brightness.weight=10.0 \
|
||||
training.image_transforms.brightness.min_max=[1.0,2.0]
|
||||
training.image_transforms.brightness.min_max="[1.0,2.0]"
|
||||
```
|
||||
|
||||
Blur images and disable saturation & hue
|
||||
@@ -41,7 +41,7 @@ Blur images and disable saturation & hue
|
||||
python lerobot/scripts/visualize_image_transforms.py \
|
||||
dataset_repo_id=lerobot/aloha_mobile_shrimp \
|
||||
training.image_transforms.sharpness.weight=10.0 \
|
||||
training.image_transforms.sharpness.min_max=[0.0,1.0] \
|
||||
training.image_transforms.sharpness.min_max="[0.0,1.0]" \
|
||||
training.image_transforms.saturation.weight=0.0 \
|
||||
training.image_transforms.hue.weight=0.0
|
||||
```
|
||||
@@ -172,4 +172,4 @@ def visualize_transforms_cli(cfg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_transforms()
|
||||
visualize_transforms_cli()
|
||||
|
||||
360
lerobot/templates/visualize_dataset_template.html
Normal file
@@ -0,0 +1,360 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<!-- # TODO(rcadene, mishig25): store the js files locally -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
|
||||
</head>
|
||||
|
||||
<!-- Use [Alpine.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
|
||||
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
|
||||
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
|
||||
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
|
||||
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
|
||||
const { keyCode, key } = e;
|
||||
if (keyCode === 32 || key === ' ') {
|
||||
e.preventDefault();
|
||||
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
|
||||
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
|
||||
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
|
||||
const lowestEpisodeId = {{ episodes }}.at(0);
|
||||
const highestEpisodeId = {{ episodes }}.at(-1);
|
||||
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
|
||||
window.location.href = `./episode_${nextEpisodeId}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
</li>
|
||||
<li>
|
||||
Frames per second: {{ dataset_info.fps }}
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<p>Episodes:</p>
|
||||
<ul class="ml-2">
|
||||
{% for episode in episodes %}
|
||||
<li class="font-mono text-sm mt-0.5">
|
||||
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
|
||||
Episode {{ episode }}
|
||||
</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- Toggle sidebar button -->
|
||||
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
|
||||
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
|
||||
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
|
||||
</button>
|
||||
|
||||
<!-- Content -->
|
||||
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
|
||||
<h1 class="text-xl font-bold mt-4 font-mono">
|
||||
Episode {{ episode_id }}
|
||||
</h1>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{% for video_info in videos_info %}
|
||||
<div class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video autoplay muted loop type="video/mp4" class="min-w-64" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
const pc = (100 / video.duration) * time;
|
||||
$refs.slider.value = pc;
|
||||
dygraphTime = time;
|
||||
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
|
||||
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
|
||||
|
||||
updateTimeQuery(time.toFixed(2));
|
||||
}
|
||||
}" @ended="() => {
|
||||
$refs.btnPlay.classList.remove('hidden');
|
||||
$refs.btnPause.classList.add('hidden');
|
||||
}"
|
||||
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
|
||||
<source src="{{ video_info.url }}">
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
<!-- Shortcuts info -->
|
||||
<div class="text-sm hidden md:block">
|
||||
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to next episode, <span class="font-mono">Arrow Up</span> to go to previous episode.
|
||||
</div>
|
||||
|
||||
<!-- Controllers -->
|
||||
<div class="flex gap-1 text-3xl items-center">
|
||||
<!-- Play button: starts hidden (the pause button is shown instead) and swaps visibility with the pause button on click. -->
<button x-ref="btnPlay" class="-rotate-90 hidden" title="Play. Toggle with Space" @click="() => {
videos.forEach(video => video.play());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">🔽</button>
|
||||
<button x-ref="btnPause" title="Pause. Toggle with Space" @click="() => {
|
||||
videos.forEach(video => video.pause());
|
||||
$refs.btnPlay.classList.toggle('hidden');
|
||||
$refs.btnPause.classList.toggle('hidden');
|
||||
}">⏸️</button>
|
||||
<button title="Jump backward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime -= 5)))">⏪</button>
|
||||
<button title="Jump forward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime += 5)))">⏩</button>
|
||||
<button title="Rewind from start"
|
||||
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
|
||||
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
|
||||
const sliderValue = $refs.slider.value;
|
||||
$refs.btnPause.click();
|
||||
videos.forEach(video => {
|
||||
const time = (video.duration * sliderValue) / 100;
|
||||
video.currentTime = time;
|
||||
});
|
||||
}" />
|
||||
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
|
||||
0:00
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Graph -->
|
||||
<div class="flex gap-2 mb-4 flex-wrap">
|
||||
<div>
|
||||
<div id="graph" @mouseleave="() => {
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
dygraphTime = video.currentTime;
|
||||
}">
|
||||
</div>
|
||||
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
|
||||
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
|
||||
Time: 0.00s
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
|
||||
<th class="border border-slate-700">
|
||||
<div class="flex gap-x-2 justify-between px-2">
|
||||
<input type="checkbox" :checked="isColumnChecked(colIndex)"
|
||||
@change="toggleColumn(colIndex)">
|
||||
<p x-text="`${columnNames[colIndex]}`"></p>
|
||||
</div>
|
||||
</th>
|
||||
</template>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<template x-for="(row, rowIndex) in rows">
|
||||
<tr class="odd:bg-gray-800 even:bg-gray-900">
|
||||
<td class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 font-semibold px-1">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`Motor ${rowIndex}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
<td x-show="cell" class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 justify-between px-2">
|
||||
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
|
||||
<span x-text="`${cell.value.toFixed(2)}`"
|
||||
:style="`color: ${cell.color}`"></span>
|
||||
</div>
|
||||
</td>
|
||||
</template>
|
||||
</tr>
|
||||
</template>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<div id="labels" class="hidden">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function createAlpineData() {
|
||||
return {
|
||||
// state
|
||||
dygraph: null,
|
||||
currentFrameData: null,
|
||||
columnNames: ["state", "action", "pred action"],
|
||||
nColumns: {% if has_policy %}3{% else %}2{% endif %},
|
||||
checked: [],
|
||||
dygraphTime: 0.0,
|
||||
dygraphIndex: 0,
|
||||
videos: null,
|
||||
video: null,
|
||||
colors: null,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
|
||||
pixelsPerPoint: 0.01,
|
||||
legend: 'always',
|
||||
labelsDiv: document.getElementById('labels'),
|
||||
labelsKMB: true,
|
||||
strokeWidth: 1.5,
|
||||
pointClickCallback: (event, point) => {
|
||||
this.dygraphTime = point.xval;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
highlightCallback: (event, x, points, row, seriesName) => {
|
||||
this.dygraphTime = x;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
drawCallback: (dygraph, is_initial) => {
|
||||
if (is_initial) {
|
||||
// dygraph initialization
|
||||
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
|
||||
this.colors = this.dygraph.getColors();
|
||||
this.checked = Array(this.colors.length).fill(true);
|
||||
|
||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||
const colors = [];
|
||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
let lightnessIdx = 0;
|
||||
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
|
||||
for (let i = 0; i < seriesNames.length; i += chunkSize) {
|
||||
const lightness = LIGHTNESS[lightnessIdx];
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
|
||||
const color = `hsl(${hue}, 100%, ${lightness}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
lightnessIdx += 1;
|
||||
}
|
||||
this.dygraph.updateOptions({ colors });
|
||||
this.colors = colors;
|
||||
|
||||
this.updateTableValues();
|
||||
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
let time = params.get("t");
|
||||
if(time){
|
||||
time = parseFloat(time);
|
||||
this.videos.forEach(video => (video.currentTime = time));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
},
|
||||
|
||||
//#region Table Data
|
||||
|
||||
// turn dygraph's 1D data (at a given time t) to 2D data that whose columns names are defined in this.columnNames.
|
||||
// 2d data view is used to create html table element.
|
||||
get rows() {
|
||||
if (!this.currentFrameData) {
|
||||
return [];
|
||||
}
|
||||
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
|
||||
return Array.from({
|
||||
length: columnSize
|
||||
}, (_, rowIndex) => {
|
||||
const row = [
|
||||
this.currentFrameData[rowIndex] || null,
|
||||
this.currentFrameData[rowIndex + columnSize] || null,
|
||||
];
|
||||
if (this.nColumns === 3) {
|
||||
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
|
||||
}
|
||||
return row;
|
||||
});
|
||||
},
|
||||
isRowChecked(rowIndex) {
|
||||
return this.rows[rowIndex].every(cell => cell && cell.checked);
|
||||
},
|
||||
isColumnChecked(colIndex) {
|
||||
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
|
||||
},
|
||||
toggleRow(rowIndex) {
|
||||
const newState = !this.isRowChecked(rowIndex);
|
||||
this.rows[rowIndex].forEach(cell => {
|
||||
if (cell) cell.checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
toggleColumn(colIndex) {
|
||||
const newState = !this.isColumnChecked(colIndex);
|
||||
this.rows.forEach(row => {
|
||||
if (row[colIndex]) row[colIndex].checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
|
||||
// given time t, update the values in the html table with "data[t]"
|
||||
updateTableValues(time) {
|
||||
if (!this.colors) {
|
||||
return;
|
||||
}
|
||||
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
|
||||
if (isNaN(pc)) pc = 0;
|
||||
const index = Math.floor(pc * this.dygraph.numRows() / 100);
|
||||
// slice(1) to remove the timestamp point that we do not need
|
||||
const labels = this.dygraph.getLabels().slice(1);
|
||||
const values = this.dygraph.rawData_[index].slice(1);
|
||||
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
|
||||
this.colors.length).fill(true);
|
||||
this.currentFrameData = labels.map((label, idx) => ({
|
||||
label,
|
||||
value: values[idx],
|
||||
color: this.colors[idx],
|
||||
checked: checkedNew[idx],
|
||||
}));
|
||||
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
|
||||
if (shouldUpdateVisibility) {
|
||||
this.checked = checkedNew;
|
||||
this.dygraph.setVisibility(this.checked);
|
||||
}
|
||||
},
|
||||
|
||||
//#endregion
|
||||
|
||||
updateTimeQuery(time) {
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
params.set("t", time);
|
||||
url.search = params.toString();
|
||||
window.history.replaceState({}, '', url.toString());
|
||||
},
|
||||
|
||||
|
||||
formatTime(time) {
|
||||
var hours = Math.floor(time / 3600);
|
||||
var minutes = Math.floor((time % 3600) / 60);
|
||||
var seconds = Math.floor(time % 60);
|
||||
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
|
||||
10 ?
|
||||
'0' + seconds : seconds);
|
||||
}
|
||||
};
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
BIN
media/koch/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 326 KiB |
BIN
media/koch/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 312 KiB |
BIN
media/koch/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 469 KiB |
BIN
media/koch/leader_rest.webp
Normal file
|
After Width: | Height: | Size: 339 KiB |
BIN
media/koch/leader_rotated.webp
Normal file
|
After Width: | Height: | Size: 232 KiB |
BIN
media/koch/leader_zero.webp
Normal file
|
After Width: | Height: | Size: 484 KiB |
BIN
media/tutorial/koch_v1_1_leader_follower.webp
Normal file
|
After Width: | Height: | Size: 58 KiB |
BIN
media/tutorial/visualize_dataset_html.webp
Normal file
|
After Width: | Height: | Size: 121 KiB |
772
poetry.lock
generated
@@ -38,12 +38,12 @@ einops = ">=0.8.0"
|
||||
pymunk = ">=6.6.0"
|
||||
zarr = ">=2.17.0"
|
||||
numba = ">=0.59.0"
|
||||
torch = "^2.2.1"
|
||||
torch = ">=2.2.1"
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = "^0.27.2"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.23.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
@@ -54,15 +54,19 @@ pre-commit = {version = ">=3.7.0", optional = true}
|
||||
debugpy = {version = ">=1.8.1", optional = true}
|
||||
pytest = {version = ">=8.1.0", optional = true}
|
||||
pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = "^2.19.0"
|
||||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
deepdiff = ">=7.0.1"
|
||||
scikit-image = {version = "^0.23.2", optional = true}
|
||||
pandas = {version = "^2.2.2", optional = true}
|
||||
pytest-mock = {version = "^3.14.0", optional = true}
|
||||
flask = ">=3.0.3"
|
||||
pandas = {version = ">=2.2.2", optional = true}
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
dynamixel-sdk = {version = ">=3.7.31", optional = true}
|
||||
pynput = {version = ">=1.7.7", optional = true}
|
||||
# TODO(rcadene, salibert): 71.0.1 has a bug
|
||||
setuptools = {version = "!=71.0.1", optional = true}
|
||||
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -71,9 +75,10 @@ pusht = ["gym-pusht"]
|
||||
xarm = ["gym-xarm"]
|
||||
aloha = ["gym-aloha"]
|
||||
dev = ["pre-commit", "debugpy"]
|
||||
test = ["pytest", "pytest-cov", "pytest-mock"]
|
||||
test = ["pytest", "pytest-cov"]
|
||||
umi = ["imagecodecs"]
|
||||
video_benchmark = ["scikit-image", "pandas"]
|
||||
koch = ["dynamixel-sdk", "pynput"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
@@ -106,7 +111,6 @@ exclude = [
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
ignore-init-module-imports = true
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -13,8 +13,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .utils import DEVICE
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEVICE, KOCH_ROBOT_CONFIG_PATH
|
||||
|
||||
|
||||
def pytest_collection_finish():
    """Print which device the tests run with.

    NOTE(review): the name matches pytest's collection-finish hook, so this is presumably
    invoked by pytest once test collection is done - confirm this file is a conftest.
    """
    device_banner = f"\nTesting with {DEVICE=}"
    print(device_banner)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def is_koch_available():
    """Session-scoped fixture telling whether a Koch robot can be created and connected to.

    Returns:
        True when the robot config loads, the robot is built and `connect()` succeeds.
        False when any of these steps raises; the exception is printed rather than propagated.
    """
    try:
        # Imported inside the try block so that a failing import is reported as
        # "not available" as well.
        from lerobot.common.robot_devices.robots.factory import make_robot

        koch_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
        koch = make_robot(koch_cfg)
        koch.connect()
        del koch
    except Exception as err:
        print("A koch robot is not available.")
        print(err)
        return False
    return True
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9f9347c8d9ac90ee44e6dd86f65043438168df6bbe4bab2d2b875e55ef7376ef
|
||||
size 1488
|
||||
oid sha256:7841afb9ef99c0601448c43a20c25eb029440c73816319c67c5d7e1c5cde2445
|
||||
size 136
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||
size 188
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:02fc4ea25766269f65752a60b0594c43d799b0ae528cd773bf024b064b5aa329
|
||||
oid sha256:03508d82db846a804aef1a28aec3cb9572e3105b55a02b6ddbb09b2522d57b84
|
||||
size 4344
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:55d7b1a06fe3e3051482752740074348bdb5fc98fb2e305b06d6203994117b27
|
||||
oid sha256:7009b3d2f14d6af497eeb32a52332e79cb9c07db24a6c2bbfbeffbaa8151dd69
|
||||
size 592448
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:98329e4b40e9be0d63f7d36da9d86c44bbe7eeeb1b10d3ba973c923f3be70867
|
||||
oid sha256:34ece24fb6b302db0b68987858509f31713fb299faa9a9d34b8fd68f10bc3100
|
||||
size 247
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54e42cdfd016a0ced2ab1fe2966a8c15a2384e0dbe1a2fe87433a2d1b8209ac0
|
||||
size 5220057
|
||||
oid sha256:a70cc17019407cf6bee44fa2c78b4f29e48eb1696aa1a4ff4c048ba256574523
|
||||
size 6356921
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:af1ded2a244cb47a96255b75f584a643edf6967e13bb5464b330ffdd9d7ad859
|
||||
size 5284692
|
||||
oid sha256:2b35992036e6dcee7d4df6d1675d55d1dd2d658b2d65442737e709895699a2f0
|
||||
size 5084448
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:13d1bebabd79984fd6715971be758ef9a354495adea5e8d33f4e7904365e112b
|
||||
size 5258380
|
||||
oid sha256:3aa92e6b6bd0e39f6de530ea6a270671db7350cdc101c9d9030c775539c708c1
|
||||
size 5441406
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f33bc6810f0b91817a42610364cb49ed1b99660f058f0f9407e6f5920d0aee02
|
||||
size 1008
|
||||
oid sha256:4ee862b1a6dc1d11df77c36c47ea00db88ad35a48e4d71c2940ad26b55fe2167
|
||||
size 136
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||
size 188
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7b58d6c89e936a781a307805ebecf0dd473fbc02d52a7094da62e54bffb9454a
|
||||
oid sha256:095c30bfe3c5da168c85aceef905e74e2142866332282965aa6812f6e6e48448
|
||||
size 4344
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a08be578285cbe2d35b78f150d464ff3e10604a9865398c976983e0d711774f9
|
||||
oid sha256:98859f2d87e1a0abb9a930a82af623504b3efb26f70fe576f05bab7f19024427
|
||||
size 788528
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:34e36233477c8aa0b0840314ddace072062d4f486d06546bbd6550832c370065
|
||||
oid sha256:38cf4116a65cb92a5c43f9b9da7a7b81cfa9168b17605c8c456f7d3a3a23b77a
|
||||
size 247
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:66e7349a4a82ca6042a7189608d01eb1cfa38d100d039b5445ae1a9e65d824ab
|
||||
size 14470946
|
||||
oid sha256:596dda720d378a44b6b61a6a72b44bec3e55e85198bca37f9dace6fe84af7ff0
|
||||
size 16062396
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a2146f0c10c9f2611e57e617983aa4f91ad681b4fc50d91b992b97abd684f926
|
||||
size 11662185
|
||||
oid sha256:c614bbaf93d65354a82001b357682a0bd36f9603685f6c735c5e377b763d0bdb
|
||||
size 10317415
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5affbaf1c48895ba3c626e0d8cf1309e5f4ec6bbaa135313096f52a22de66c05
|
||||
size 11410342
|
||||
oid sha256:868788028a38334b6b566cb17ffcc2ace2ec2b2b68ff2a58b6d29eb3c3e2ec1f
|
||||
size 9516445
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6c2b195ca91b88fd16422128d386d2cabd808a1862c6d127e6bf2e83e1fe819a
|
||||
size 448
|
||||
oid sha256:f365a02b052a2697b1558f4ab9b813f0d4ba46a5bc6ae3da30bbc4b135426aa6
|
||||
size 136
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cf148247bf191c7f7e8af738a7b9e147f9ffffeec0e4b9d1c4783c4e384da7eb
|
||||
size 33
|
||||
oid sha256:50e40e4c2bb523fca0b54e9a9635281312e9c6f9d757db03c06a0865c5508f29
|
||||
size 188
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b360b6b956d2adcb20589947c553348ef1eb6b70743c989dcbe95243d8592ce5
|
||||
oid sha256:5c96f47b569b7af82e05200213d733626664150aa7c5ae3298fd04a2138a2023
|
||||
size 4344
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3f5c3926b4d4da9271abefcdf6a8952bb1f13258a9c39fe0fd223f548dc89dcb
|
||||
oid sha256:75f53d221827f17cc2ded3908452e24331b39b79dc3a26f2b9d89a6e6894baab
|
||||
size 887728
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4993b05fb026619eec5eb70db8cadaa041ba4ab92d38b4a387167ace03b1018b
|
||||
oid sha256:d394d451929b805f2d94f9fc5b12d15c31cfc494df76d7d642b63378b8ba0131
|
||||
size 247
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bd25d17ef5b7500386761b5e32920879bbdcafe0e17a8a8845628525d861e644
|
||||
size 10231081
|
||||
oid sha256:73ddb898f83589b4bcabe978e46e75f20be215492f115bf6ebc98f1d01e1eff8
|
||||
size 9696507
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5b557acbfeb0681c0a38e47263d945f6cd3a03461298d8b17209c81e3fd0aae8
|
||||
size 9701371
|
||||
oid sha256:d3d993977bee96882732d4a9c9d082c356fc9fcd8199c027b016207d60494c2f
|
||||
size 8957007
|
||||
|
||||