Compare commits

...

5 Commits

Author SHA1 Message Date
Remi Cadene
33d149000a WIP faster act 2024-06-09 11:57:06 +00:00
Remi Cadene
b65247feee Add mobile and neck 2024-06-06 14:01:14 +00:00
Remi Cadene
5e85a2c50b Add reachy2 dataset, policy, env 2024-06-04 12:31:59 +00:00
Remi Cadene
a56626cf9c Add custom visualize_dataset.py 2024-06-03 15:47:12 +00:00
Remi Cadene
44ba4ed566 Fix aloha (WIP: do not train in sim) 2024-06-03 14:47:06 +00:00
14 changed files with 1055 additions and 184 deletions

View File

@@ -28,7 +28,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -0,0 +1,189 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act
"""
import gc
import re
import shutil
from pathlib import Path
import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
def get_cameras(hdf5_data):
# ignore depth channel, not currently handled
# TODO(rcadene): add depth
rgb_cameras = [key for key in hdf5_data["/observations/images_ids"].keys() if "depth" not in key] # noqa: SIM118
return rgb_cameras
def check_format(raw_dir) -> bool:
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
assert len(hdf5_paths) != 0
for hdf5_path in hdf5_paths:
with h5py.File(hdf5_path, "r") as data:
assert "/action" in data
assert "/observations/qpos" in data
assert data["/action"].ndim == 2
assert data["/observations/qpos"].ndim == 2
num_frames = data["/action"].shape[0]
assert num_frames == data["/observations/qpos"].shape[0]
for camera in get_cameras(data):
assert num_frames == data[f"/observations/images_ids/{camera}"].shape[0]
assert (raw_dir / hdf5_path.name.replace(".hdf5", f"_{camera}.mp4")).exists()
# assert data[f"/observations/images_ids/{camera}"].ndim == 4
# b, h, w, c = data[f"/observations/images_ids/{camera}"].shape
# assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(raw_dir, out_dir, fps, video, debug):
hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
match = re.search(r"_(\d+).hdf5", ep_path.name)
if not match:
raise ValueError(ep_path.name)
raw_ep_idx = int(match.group(1))
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
ep_dict = {}
videos_dir = out_dir / "videos"
videos_dir.mkdir(parents=True, exist_ok=True)
for camera in get_cameras(ep):
img_key = f"observation.images.{camera}"
raw_fname = f"episode_{raw_ep_idx}_{camera}.mp4"
new_fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
shutil.copy(str(raw_dir / raw_fname), str(videos_dir / new_fname))
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{new_fname}", "timestamp": i / fps} for i in range(num_frames)
]
ep_dict["observation.state"] = state
if "/observations/velocity" in ep:
ep_dict["observation.velocity"] = velocity
if "/observations/effort" in ep:
ep_dict["observation.effort"] = effort
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["next.done"] = done
# TODO(rcadene): add reward and success by computing them in sim
assert isinstance(ep_idx, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
gc.collect()
# process first episode only
if debug:
break
data_dict = concatenate_episodes(ep_dicts)
return data_dict, episode_data_index
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
# sanity check
check_format(raw_dir)
if fps is None:
fps = 30
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
hf_dataset = to_hf_dataset(data_dir, video)
info = {
"fps": fps,
"video": video,
}
return hf_dataset, episode_data_index, info

View File

@@ -233,9 +233,6 @@ class Logger:
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)

View File

@@ -139,25 +139,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
bsize = actions_hat.shape[0]
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
out_dict = {}
out_dict["l1_loss"] = l1_loss
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
out_dict["loss"] = l1_loss
return loss_dict
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
return out_dict
class ACT(nn.Module):
@@ -264,6 +265,16 @@ class ACT(nn.Module):
self._reset_parameters()
self.register_buffer(
"latent_sample",
torch.zeros(1, config.latent_dim, dtype=torch.float32),
)
self.register_buffer(
"decoder_in",
torch.zeros(config.chunk_size, 1, config.dim_model, dtype=torch.float32),
)
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
@@ -328,10 +339,7 @@ class ACT(nn.Module):
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
latent_sample = self.latent_sample
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
@@ -341,8 +349,7 @@ class ACT(nn.Module):
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
@@ -373,12 +380,7 @@ class ACT(nn.Module):
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,
device=pos_embed.device,
)
decoder_in = self.decoder_in
decoder_out = self.decoder(
decoder_in,
encoder_out,
@@ -578,6 +580,10 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
self._eps = 1e-6
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000
self.register_buffer(
"inverse_frequency",
self._temperature ** (2 * (torch.arange(self.dimension, dtype=torch.float32) // 2) / self.dimension),
)
def forward(self, x: Tensor) -> Tensor:
"""
@@ -589,8 +595,8 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
y_range = not_mask.cumsum(1, dtype=x.dtype)
x_range = not_mask.cumsum(2, dtype=x.dtype)
# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
@@ -598,9 +604,7 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
inverse_frequency = self.inverse_frequency
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)

View File

@@ -0,0 +1,13 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraReachy2-v0
state_dim: 22
action_dim: 22
fps: ${fps}
episode_length: 400
gym:
fps: ${fps}

View File

@@ -25,7 +25,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -0,0 +1,97 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: cadene/reachy2_teleop_remi
override_dataset_stats:
observation.images.cam_trunk:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_trunk: [3, 800, 1280]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_trunk: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -51,7 +51,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -49,7 +49,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -86,6 +86,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif raw_format == "aloha_dora":
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
elif raw_format == "reachy2_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.reachy2_hdf5_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
else:

View File

@@ -107,7 +107,7 @@ def update_policy(
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
loss = output_dict["loss"].mean()
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.

View File

@@ -30,48 +30,46 @@ Examples:
- Visualize data stored on a local machine:
```
local$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0 \
--save 1 \
--output-dir path/to/directory
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
local$ rerun lerobot_pusht_episode_0.rrd
--episode-indices 7 3 5 1 4
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--ws-port 9087
local$ rerun ws://localhost:9087
```
"""
import argparse
import gc
import http.server
import logging
import time
import os
import shutil
import socketserver
from pathlib import Path
import rerun as rr
import torch
import tqdm
import yaml
from bs4 import BeautifulSoup
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.utils.utils import init_logging
class EpisodeSampler(torch.utils.data.Sampler):
@@ -87,33 +85,307 @@ class EpisodeSampler(torch.utils.data.Sampler):
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch):
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def end_headers(self):
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
self.send_header("Pragma", "no-cache")
self.send_header("Expires", "0")
super().end_headers()
def visualize_dataset(
repo_id: str,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
) -> Path | None:
if save:
assert (
output_dir is not None
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
def run_server(path, port):
# Change directory to serve 'index.html` as front page
os.chdir(path)
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id)
with socketserver.TCPServer(("", port), NoCacheHTTPRequestHandler) as httpd:
logging.info(f"Serving HTTP on 0.0.0.0 port {port} (http://0.0.0.0:{port}/) ...")
httpd.serve_forever()
def create_html_page(page_title: str):
"""Create a html page with beautiful soop with default doctype, meta, header and title."""
soup = BeautifulSoup("", "html.parser")
doctype = soup.new_tag("!DOCTYPE html")
soup.append(doctype)
html = soup.new_tag("html", lang="en")
soup.append(html)
head = soup.new_tag("head")
html.append(head)
meta_charset = soup.new_tag("meta", charset="UTF-8")
head.append(meta_charset)
meta_viewport = soup.new_tag(
"meta", attrs={"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
)
head.append(meta_viewport)
title = soup.new_tag("title")
title.string = page_title
head.append(title)
body = soup.new_tag("body")
html.append(body)
main_div = soup.new_tag("div")
body.append(main_div)
return soup, head, body
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
has_inference = inference_results is not None
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
header += [f"action_{i}" for i in range(dim_action)]
if has_inference:
assert "actions" in inference_results
assert "loss" in inference_results
dim_pred_action = inference_results["actions"].shape[2]
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
header += ["loss"]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
if has_inference:
num_frames = len(rows)
assert num_frames == inference_results["actions"].shape[0]
assert num_frames == inference_results["loss"].shape[0]
for i in range(num_frames):
rows[i] += inference_results["actions"][i, 0].tolist()
rows[i] += [inference_results["loss"][i].item()]
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
"""Write a javascript file containing logic to synchronize camera feeds and timeseries."""
s = ""
s += "document.addEventListener('DOMContentLoaded', function () {\n"
for i, key in enumerate(dataset.video_frame_keys):
s += f" const video{i} = document.getElementById('video_{key}');\n"
s += " const slider = document.getElementById('videoControl');\n"
s += " const playButton = document.getElementById('playButton');\n"
s += f" const dygraph = new Dygraph(document.getElementById('graph'), '{ep_csv_fname}', " + "{\n"
s += " pixelsPerPoint: 0.01,\n"
s += " legend: 'always',\n"
s += " labelsDiv: document.getElementById('labels'),\n"
s += " labelsSeparateLines: true,\n"
s += " labelsKMB: true,\n"
s += " highlightCircleSize: 1.5,\n"
s += " highlightSeriesOpts: {\n"
s += " strokeWidth: 1.5,\n"
s += " strokeBorderWidth: 1,\n"
s += " highlightCircleSize: 3\n"
s += " }\n"
s += " });\n"
s += "\n"
s += " // Function to play both videos\n"
s += " playButton.addEventListener('click', function () {\n"
for i in range(len(dataset.video_frame_keys)):
s += f" video{i}.play();\n"
s += " // playButton.disabled = true; // Optional: disable button after playing\n"
s += " });\n"
s += "\n"
s += " // Update the video time when the slider value changes\n"
s += " slider.addEventListener('input', function () {\n"
s += " const sliderValue = slider.value;\n"
for i in range(len(dataset.video_frame_keys)):
s += f" const time{i} = (video{i}.duration * sliderValue) / 100;\n"
for i in range(len(dataset.video_frame_keys)):
s += f" video{i}.currentTime = time{i};\n"
s += " });\n"
s += "\n"
s += " // Synchronize slider with the video's current time\n"
s += " const syncSlider = (video) => {\n"
s += " video.addEventListener('timeupdate', function () {\n"
s += " if (video.duration) {\n"
s += " const pc = (100 / video.duration) * video.currentTime;\n"
s += " slider.value = pc;\n"
s += " const index = Math.floor(pc * dygraph.numRows() / 100);\n"
s += " dygraph.setSelection(index, undefined, true, true);\n"
s += " }\n"
s += " });\n"
s += " };\n"
s += "\n"
for i in range(len(dataset.video_frame_keys)):
s += f" syncSlider(video{i});\n"
s += "\n"
s += "});\n"
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(s)
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
"""Write an html file containg video feeds and timeseries associated to an episode."""
soup, head, body = create_html_page("")
css_style = soup.new_tag("style")
css_style.string = ""
css_style.string += "#labels > span.highlight {\n"
css_style.string += " border: 1px solid grey;\n"
css_style.string += "}"
head.append(css_style)
# Add videos from camera feeds
videos_control_div = soup.new_tag("div")
body.append(videos_control_div)
videos_div = soup.new_tag("div")
videos_control_div.append(videos_div)
def create_video(id, src):
video = soup.new_tag("video", id=id, width="320", height="240", controls="")
source = soup.new_tag("source", src=src, type="video/mp4")
video.string = "Your browser does not support the video tag."
video.append(source)
return video
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
for key in dataset.video_frame_keys:
# Example of video_path: 'videos/observation.image_episode_000004.mp4'
video_path = dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
videos_div.append(create_video(f"video_{key}", video_path))
# Add controls for videos and graph
control_div = soup.new_tag("div")
videos_control_div.append(control_div)
button_div = soup.new_tag("div")
control_div.append(button_div)
button = soup.new_tag("button", id="playButton")
button.string = "Play Videos"
button_div.append(button)
slider_div = soup.new_tag("div")
control_div.append(slider_div)
slider = soup.new_tag("input", type="range", id="videoControl", min="0", max="100", value="0", step="1")
control_div.append(slider)
# Add graph of states/actions, and its labels
graph_labels_div = soup.new_tag("div", style="display: flex;")
body.append(graph_labels_div)
graph_div = soup.new_tag("div", id="graph", style="flex: 1; width: 85%")
graph_labels_div.append(graph_div)
labels_div = soup.new_tag("div", id="labels", style="flex: 1; width: 15%")
graph_labels_div.append(labels_div)
# add dygraph library
script = soup.new_tag("script", type="text/javascript", src=js_fname)
body.append(script)
script_dygraph = soup.new_tag(
"script",
type="text/javascript",
src="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.js",
)
body.append(script_dygraph)
link_dygraph = soup.new_tag(
"link", rel="stylesheet", href="https://cdn.jsdelivr.net/npm/dygraphs@2.1.0/dist/dygraph.min.css"
)
body.append(link_dygraph)
# Write as a html file
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(soup.prettify())
def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
"""Write an html file containing information related to the dataset and a list of links to
html pages of episodes."""
soup, head, body = create_html_page("TODO")
h3 = soup.new_tag("h3")
h3.string = "TODO"
body.append(h3)
ul_info = soup.new_tag("ul")
body.append(ul_info)
li_info = soup.new_tag("li")
li_info.string = f"Number of samples/frames: {dataset.num_samples}"
ul_info.append(li_info)
li_info = soup.new_tag("li")
li_info.string = f"Number of episodes: {dataset.num_episodes}"
ul_info.append(li_info)
li_info = soup.new_tag("li")
li_info.string = f"Frames per second: {dataset.fps}"
ul_info.append(li_info)
# li_info = soup.new_tag("li")
# li_info.string = f"Size: {format_big_number(dataset.hf_dataset.info.size_in_bytes)}B"
# ul_info.append(li_info)
ul = soup.new_tag("ul")
body.append(ul)
for ep_idx, ep_html_fname in zip(ep_indices, ep_html_fnames, strict=False):
li = soup.new_tag("li")
ul.append(li)
a = soup.new_tag("a", href=ep_html_fname)
a.string = f"Episode number {ep_idx}"
li.append(a)
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w", encoding="utf-8") as f:
f.write(soup.prettify())
def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
policy.eval()
policy.to(device)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
@@ -124,70 +396,104 @@ def visualize_dataset(
sampler=episode_sampler,
)
logging.info("Starting Rerun")
if mode not in ["local", "distant"]:
raise ValueError(mode)
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
logging.info("Running inference")
inference_results = {}
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.inference_mode():
output_dict = policy.forward(batch)
# display each camera image
for key in dataset.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
for key in output_dict:
if key not in inference_results:
inference_results[key] = []
inference_results[key].append(output_dict[key].to("cpu"))
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
for key in inference_results:
inference_results[key] = torch.cat(inference_results[key])
# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
return inference_results
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
def visualize_dataset(
repo_id: str,
episode_indices: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
port: int = 9090,
force_overwrite: bool = True,
policy_repo_id: str | None = None,
policy_ckpt_path: Path | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Path | None:
init_logging()
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
has_policy = policy_repo_id or policy_ckpt_path
if mode == "local" and save:
# save .rrd locally
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
rr.save(rrd_path)
return rrd_path
if has_policy:
logging.info("Loading policy")
if policy_repo_id:
pretrained_policy_path = Path(snapshot_download(policy_repo_id))
elif policy_ckpt_path:
pretrained_policy_path = Path(policy_ckpt_path)
policy = ACTPolicy.from_pretrained(pretrained_policy_path)
with open(pretrained_policy_path / "config.yaml") as f:
cfg = yaml.safe_load(f)
delta_timestamps = cfg["training"]["delta_timestamps"]
else:
delta_timestamps = None
elif mode == "distant":
# stop the process from exiting since it is serving the websocket connection
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset/{repo_id}"
output_dir = Path(output_dir)
if force_overwrite and output_dir.exists():
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
ln_videos_dir = output_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
if episode_indices is None:
episode_indices = list(range(dataset.num_episodes))
logging.info("Writing html")
ep_html_fnames = []
for episode_index in tqdm.tqdm(episode_indices):
inference_results = None
if has_policy:
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
if inference_results_path.exists():
inference_results = load_file(inference_results_path)
else:
inference_results = run_inference(dataset, episode_index, policy)
save_file(inference_results, inference_results_path)
# write states and actions in a csv
ep_csv_fname = f"episode_{episode_index}.csv"
write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
js_fname = f"episode_{episode_index}.js"
write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
# write a html page to view videos and timeseries
ep_html_fname = f"episode_{episode_index}.html"
write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
ep_html_fnames.append(ep_html_fname)
write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
if serve:
run_server(output_dir, port)
def main():
@@ -197,13 +503,51 @@ def main():
"--repo-id",
type=str,
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--episode-index",
"--episode-indices",
type=int,
required=True,
help="Episode to visualize.",
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-overwrite",
type=int,
default=1,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--policy-repo-id",
type=str,
default=None,
help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
)
parser.add_argument(
"--policy-ckpt-path",
type=str,
default=None,
help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
)
parser.add_argument(
"--batch-size",
@@ -217,43 +561,6 @@ def main():
default=4,
help="Number of processes of Dataloader for loading the data.",
)
parser.add_argument(
"--mode",
type=str,
default="local",
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--ws-port",
type=int,
default=9087,
help="Web socket port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--save",
type=int,
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. ",
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
),
)
parser.add_argument(
"--output-dir",
type=str,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
args = parser.parse_args()
visualize_dataset(**vars(args))

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesnt always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.
Examples:
- Visualize data stored on a local machine:
```
local$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0 \
--save 1 \
--output-dir path/to/directory
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \
--ws-port 9087
local$ rerun ws://localhost:9087
```
"""
import argparse
import gc
import logging
import time
from pathlib import Path
import rerun as rr
import torch
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
return iter(self.frame_ids)
def __len__(self):
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch):
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy
def visualize_dataset(
repo_id: str,
episode_index: int,
batch_size: int = 32,
num_workers: int = 0,
mode: str = "local",
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
) -> Path | None:
if save:
assert (
output_dir is not None
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
logging.info("Starting Rerun")
if mode not in ["local", "distant"]:
raise ValueError(mode)
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc.collect()
if mode == "distant":
rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)
logging.info("Logging to Rerun")
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
# display each camera image
for key in dataset.camera_keys:
# TODO(rcadene): add `.compress()`? is it lossless?
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
if mode == "local" and save:
# save .rrd locally
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
rr.save(rrd_path)
return rrd_path
elif mode == "distant":
# stop the process from exiting since it is serving the websocket connection
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("Ctrl-C received. Exiting.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
)
parser.add_argument(
"--episode-index",
type=int,
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size loaded by DataLoader.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of processes of Dataloader for loading the data.",
)
parser.add_argument(
"--mode",
type=str,
default="local",
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
"--web-port",
type=int,
default=9090,
help="Web port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--ws-port",
type=int,
default=9087,
help="Web socket port for rerun.io when `--mode distant` is set.",
)
parser.add_argument(
"--save",
type=int,
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. ",
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
),
)
parser.add_argument(
"--output-dir",
type=str,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
args = parser.parse_args()
visualize_dataset(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -25,9 +25,8 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
def test_visualize_dataset(tmpdir, repo_id):
rrd_path = visualize_dataset(
repo_id,
episode_index=0,
batch_size=32,
save=True,
episode_indices=[0],
output_dir=tmpdir,
serve=False,
)
assert rrd_path.exists()