Compare commits

...

4 Commits

Author SHA1 Message Date
Remi Cadene
9ddbbd8e80 WIP 2024-08-06 17:17:07 +03:00
Remi Cadene
1da5caaf4b Revert "Revove inference"
This reverts commit ca7f207d74.
2024-08-06 17:16:42 +03:00
Remi Cadene
ca7f207d74 Revove inference 2024-08-06 17:15:52 +03:00
Remi Cadene
6b9dcadbf7 Add visualize_dataset_html.py 2024-08-06 17:07:48 +03:00
20 changed files with 1150 additions and 208 deletions

View File

@@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
root: Path | None = DATA_DIR,
root: Path | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
@@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_id = repo_id
self.root = root
if self.root is None and DATA_DIR is not None:
self.root = DATA_DIR
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, CODEBASE_VERSION, root)
self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
if self.video:
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
self.video_backend = video_backend if video_backend is not None else "pyav"
@property
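
The `root` argument now defaults to `None`, and the `DATA_DIR` fallback is resolved inside `__init__`, so every loader (`load_hf_dataset`, `load_stats`, `load_info`, `load_videos`) receives the same resolved `self.root`. A minimal sketch of the resolution order, assuming `DATA_DIR` is derived from the `DATA_DIR` environment variable as elsewhere in the codebase:

```python
import os
from pathlib import Path

# Assumption: DATA_DIR comes from the environment, as in lerobot_dataset.py.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

def resolve_root(root: Path | None) -> Path | None:
    # An explicit `root` wins; otherwise fall back to DATA_DIR. A final None
    # means "load from the Hugging Face cache, or download from the hub".
    if root is None and DATA_DIR is not None:
        return DATA_DIR
    return root
```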

View File

@@ -233,9 +233,6 @@ class Logger:
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)

View File

@@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
bsize = actions_hat.shape[0]
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
out_dict = {}
out_dict["l1_loss"] = l1_loss
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
out_dict["loss"] = l1_loss
return loss_dict
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
return out_dict
class ACTTemporalEnsembler:
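
`ACTPolicy.forward` now returns one loss value per batch element (plus the predicted actions) instead of pre-reduced scalars; the training loop takes the mean for backprop, while tools such as `visualize_dataset_html.py` can display per-frame losses. A sketch of the per-sample reduction, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: a batch of 4 chunks of 10 actions with 2 dims each.
action = torch.randn(4, 10, 2)
actions_hat = torch.randn(4, 10, 2)
action_is_pad = torch.zeros(4, 10, dtype=torch.bool)

l1_loss = F.l1_loss(action, actions_hat, reduction="none")  # (4, 10, 2)
l1_loss = l1_loss * ~action_is_pad.unsqueeze(-1)            # zero out padded steps
l1_loss = l1_loss.view(4, -1).mean(dim=1)                   # (4,): one loss per item
loss_for_backprop = l1_loss.mean()                          # what update_policy() does
```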

View File

@@ -341,7 +341,11 @@ class DiffusionModel(nn.Module):
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
# Compute average per item in the batch
bsize = loss.shape[0]
loss = loss.reshape(bsize, -1).mean(1)
return loss
class SpatialSoftmax(nn.Module):
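
The diffusion policy follows the same pattern. When no frames are padded, every batch item averages over the same number of elements, so the mean of the per-item losses reproduces the old global mean and training behavior is unchanged. A quick check under that no-padding assumption:

```python
import torch

loss = torch.rand(4, 16, 2)              # (batch, horizon, action_dim), no padding
per_item = loss.reshape(4, -1).mean(1)   # new: one value per batch element
assert torch.allclose(per_item.mean(), loss.mean())  # old scalar is recovered
```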

View File

@@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
# predicted from the (target model's) observation encoder.
consistency_loss = (
(
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:]
).sum(0)
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
# rewards.
reward_loss = (
(
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).sum(0)
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
q_value_loss = (
(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0)
# Compute state value loss as in eqn 3 of FOWM.
diff = v_targets - v_preds
# Expectile loss penalizes:
@@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
) * (diff**2)
v_value_loss = (
(
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
)
.sum(0)
.mean()
)
temporal_loss_coeffs
* raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).sum(0)
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
# We won't need these gradients again so detach.
@@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
).sum(0)
loss = (
self.config.consistency_coeff * consistency_loss
@@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
info.update(
{
"consistency_loss": consistency_loss.item(),
"reward_loss": reward_loss.item(),
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"Q_value_loss": q_value_loss,
"V_value_loss": v_value_loss,
"pi_loss": pi_loss,
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
"sum_loss": loss * self.config.horizon,
}
)
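
The TD-MPC losses adopt the same convention: each term now ends in `.sum(0)` over the time dimension while keeping the batch dimension, and the logged values stay tensors rather than `.item()` scalars. Schematically, with hypothetical shapes and a hypothetical temporal decay:

```python
import torch

horizon, batch = 5, 3
# Hypothetical exponential decay over time, shaped (horizon, 1) to broadcast.
temporal_loss_coeffs = 0.9 ** torch.arange(horizon).reshape(-1, 1)
mse = torch.rand(horizon, batch)       # stand-in for F.mse_loss(...).mean(-1)
not_pad = torch.ones(horizon, batch)   # stand-in for the ~*_is_pad masks

consistency_loss = (temporal_loss_coeffs * mse * not_pad).sum(0)  # (batch,)
```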

View File

@@ -13,7 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import nn
@@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
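
`get_pretrained_policy_path` moves into `lerobot/common/policies/utils.py` (it is deleted from `eval.py` further down) so that both `eval.py` and the new `visualize_dataset_html.py` can resolve either a hub repo ID or a local directory. A hedged usage sketch, with an illustrative repo ID:

```python
from lerobot.common.policies.utils import get_pretrained_policy_path

# Downloads a snapshot for a hub repo ID, or validates a local directory.
# "lerobot/diffusion_pusht" is only an illustrative repo ID.
pretrained_policy_path = get_pretrained_policy_path("lerobot/diffusion_pusht")
config_path = pretrained_policy_path / "config.yaml"
```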

View File

@@ -24,7 +24,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50
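
This `delta_timestamps` change (applied identically to the three ACT configs in this diff) shifts the supervised action chunk forward by one frame: the targets now start at the frame after the current observation instead of at the observation itself. A worked example with hypothetical `fps=50` and `chunk_size=100`:

```python
fps, chunk_size = 50, 100  # hypothetical values from a policy config

old = [i / fps for i in range(chunk_size)]         # [0.0, 0.02, 0.04, ...]
new = [i / fps for i in range(1, chunk_size + 1)]  # [0.02, 0.04, 0.06, ...]
assert len(old) == len(new) == chunk_size          # same length, shifted by 1/fps
```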

View File

@@ -50,7 +50,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -48,7 +48,7 @@ training:
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 50

View File

@@ -56,9 +56,6 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import Tensor, nn
from tqdm import trange
@@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
@@ -501,29 +498,6 @@ def main(
logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__":
init_logging()

View File

@@ -120,8 +120,7 @@ def update_policy(
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
loss = output_dict["loss"].mean()
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
@@ -150,14 +149,12 @@ def update_policy(
policy.update()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
**{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
**{k: v for k, v in output_dict.items() if "loss" not in k},
}
info.update({k: v for k, v in output_dict.items() if k not in info})
return info
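
Since `forward()` may now return per-sample tensors, `update_policy` reduces any key containing "loss" to a scalar mean for logging and passes the remaining outputs through untouched. A minimal sketch with a hypothetical `output_dict`:

```python
import torch

output_dict = {  # hypothetical policy.forward() output
    "loss": torch.tensor([0.5, 0.7]),    # one value per batch element
    "l1_loss": torch.tensor([0.4, 0.6]),
    "action": torch.zeros(2, 10, 2),
}

loss = output_dict["loss"].mean()  # scalar used for backprop
info = {
    "loss": loss.item(),
    **{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
    **{k: v for k, v in output_dict.items() if "loss" not in k},
}
```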

View File

@@ -108,8 +108,8 @@ def visualize_dataset(
web_port: int = 9090,
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
root: Path | None = None,
output_dir: Path | None = None,
) -> Path | None:
if save:
assert (
@@ -209,6 +209,18 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--batch-size",
type=int,
@@ -254,17 +266,6 @@ def main():
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
),
)
parser.add_argument(
"--output-dir",
type=str,
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--root",
type=str,
help="Root directory for a dataset stored on a local machine.",
)
args = parser.parse_args()
visualize_dataset(**vars(args))

View File

@@ -0,0 +1,467 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transitions from state to state, up to
the antepenultimate state associated with the ultimate action needed to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing the image modality, expect to observe
lossy compression artifacts, since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned so as not to affect the success rate.
Example of usage:
- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ open http://localhost:9090
```
- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
local$ open http://localhost:9090
```
- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4
```
- Run inference of a policy on the dataset and visualize the results:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--repo-id lerobot/pusht \
--episodes 7 3 5 1 4 \
-p lerobot/diffusion_pusht \
--policy-overrides device=cpu
```
"""
import argparse
import logging
import shutil
import warnings
from pathlib import Path
import torch
import tqdm
from flask import Flask, redirect, render_template, url_for
from safetensors.torch import load_file, save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.utils import get_pretrained_policy_path
from lerobot.common.utils.utils import init_hydra_config, init_logging
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
return iter(self.frame_ids)
def __len__(self):
return len(self.frame_ids)
def run_server(
dataset: LeRobotDataset,
episodes: list[int],
host: str,
port: str,
static_folder: Path,
template_folder: Path,
has_policy: bool = False,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def index():
# home page redirects to the first episode page
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
first_episode_id = episodes[0]
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=first_episode_id,
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id):
dataset_info = {
"repo_id": dataset.repo_id,
"num_samples": dataset.num_samples,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
video_paths = get_episode_video_paths(dataset, episode_id)
videos_info = [
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
for video_path in video_paths
]
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
ep_csv_url=ep_csv_url,
has_policy=has_policy,
)
app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features
has_inference = inference_results is not None
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = len(dataset.hf_dataset["observation.state"][0])
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = len(dataset.hf_dataset["action"][0])
header += [f"action_{i}" for i in range(dim_action)]
if has_inference:
if "action" in inference_results:
dim_pred_action = inference_results["action"].shape[1]
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
for key in inference_results:
if "loss" in key:
header += [key]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
if has_inference:
num_frames = len(rows)
if "action" in inference_results:
assert num_frames == inference_results["action"].shape[0]
for i in range(num_frames):
rows[i] += inference_results["action"][i].tolist()
for key in inference_results:
if "loss" in key:
assert num_frames == inference_results[key].shape[0]
for i in range(num_frames):
rows[i] += [inference_results[key][i].item()]
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# get first frame of episode (hack to get video_path of the episode)
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
return [
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
for key in dataset.video_frame_keys
]
def run_inference(
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
):
if policy_method not in ["select_action", "forward"]:
raise ValueError(
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
)
policy.eval()
policy.to(device)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
batch_size=1 if policy_method == "select_action" else batch_size,
sampler=episode_sampler,
drop_last=False,
)
warned_ndim_eq_0 = False
warned_ndim_gt_2 = False
logging.info("Running inference")
inference_results = {}
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.inference_mode():
if policy_method == "select_action":
gt_action = batch.pop("action")
output_dict = {"action": policy.select_action(batch)}
batch["action"] = gt_action
elif policy_method == "forward":
output_dict = policy.forward(batch)
# TODO(rcadene): Save and display all predicted actions at a given timestamp
# Save predicted action for the next timestamp only
output_dict["action"] = output_dict["action"][:, 0, :]
for key in output_dict:
if output_dict[key].ndim == 0:
if not warned_ndim_eq_0:
warnings.warn(
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
stacklevel=1,
)
warned_ndim_eq_0 = True
continue
if output_dict[key].ndim > 2:
if not warned_ndim_gt_2:
warnings.warn(
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
stacklevel=1,
)
warned_ndim_gt_2 = True
continue
if key not in inference_results:
inference_results[key] = []
inference_results[key].append(output_dict[key].to("cpu"))
for key in inference_results:
inference_results[key] = torch.cat(inference_results[key])
return inference_results
def visualize_dataset_html(
repo_id: str,
root: Path | None = None,
episodes: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
policy_method: str = "select_action",
pretrained_policy_name_or_path: str | None = None,
policy_overrides: list[str] | None = None,
) -> Path | None:
init_logging()
has_policy = pretrained_policy_name_or_path is not None
if has_policy:
logging.info("Loading policy")
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
dataset = make_dataset(hydra_cfg)
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
if policy_method == "select_action":
# Do not load previous observations or future actions, to simulate that the observations come from
# an environment.
dataset.delta_timestamps = None
else:
dataset = LeRobotDataset(repo_id, root=root)
if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
if output_dir is None:
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
if has_policy:
ckpt_str = pretrained_policy_path.parts[-2]
exp_name = pretrained_policy_path.parts[-4]
output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
# Create a symlink from the dataset video folder containing the mp4 files to the output directory
# so that the http server can access the mp4 files.
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
if episodes is None:
episodes = list(range(dataset.num_episodes))
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
inference_results = None
if has_policy:
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
if inference_results_path.exists():
inference_results = load_file(inference_results_path)
else:
inference_results = run_inference(
dataset,
episode_index,
policy,
policy_method,
num_workers=hydra_cfg.training.num_workers,
batch_size=hydra_cfg.training.batch_size,
device=hydra_cfg.device,
)
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
save_file(inference_results, inference_results_path)
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--root",
type=Path,
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--episodes",
type=int,
nargs="*",
default=None,
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
)
parser.add_argument(
"--output-dir",
type=Path,
default=None,
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
)
parser.add_argument(
"--serve",
type=int,
default=1,
help="Launch web server.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Web host used by the http server.",
)
parser.add_argument(
"--port",
type=int,
default=9090,
help="Web port used by the http server.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Delete the output directory if it exists already.",
)
parser.add_argument(
"--policy-method",
type=str,
default="select_action",
choices=["select_action", "forward"],
help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
)
parser.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser.add_argument(
"--policy-overrides",
nargs="*",
help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
visualize_dataset_html(**vars(args))
if __name__ == "__main__":
main()
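
When a policy is given, per-episode inference results are cached as safetensors next to the CSVs, so re-running the script reuses them. They can also be inspected directly; the path below is illustrative:

```python
from safetensors.torch import load_file

# Illustrative path, produced by visualize_dataset_html.py when -p is given.
results = load_file(
    "outputs/visualize_dataset_html/lerobot/pusht_myexp_005000_forward/episode_0.safetensors"
)
print(list(results))           # e.g. ['action', 'l1_loss', 'loss']
print(results["loss"].shape)   # one value per frame of the episode
```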

View File

@@ -0,0 +1,360 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- # TODO(rcadene, mishig25): store the js files locally -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
<script src="https://cdn.tailwindcss.com"></script>
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
</head>
<!-- Use [Alpine.js](https://alpinejs.dev), a lightweight and easy-to-learn JS framework -->
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
// Use the space bar to play and pause, instead of the default action (e.g. scrolling)
const { keyCode, key } = e;
if (keyCode === 32 || key === ' ') {
e.preventDefault();
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
const lowestEpisodeId = {{ episodes }}.at(0);
const highestEpisodeId = {{ episodes }}.at(-1);
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
window.location.href = `./episode_${nextEpisodeId}`;
}
}
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_samples }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}
</li>
<li>
Frames per second: {{ dataset_info.fps }}
</li>
</ul>
<p>Episodes:</p>
<ul class="ml-2">
{% for episode in episodes %}
<li class="font-mono text-sm mt-0.5">
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
Episode {{ episode }}
</a>
</li>
{% endfor %}
</ul>
</div>
<!-- Toggle sidebar button -->
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
</button>
<!-- Content -->
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
<h1 class="text-xl font-bold mt-4 font-mono">
Episode {{ episode_id }}
</h1>
<!-- Videos -->
<div class="flex flex-wrap gap-1">
{% for video_info in videos_info %}
<div class="max-w-96">
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video autoplay muted loop type="video/mp4" class="min-w-64" @timeupdate="() => {
if (video.duration) {
const time = video.currentTime;
const pc = (100 / video.duration) * time;
$refs.slider.value = pc;
dygraphTime = time;
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
dygraph.setSelection(dygraphIndex, undefined, true, true);
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
updateTimeQuery(time.toFixed(2));
}
}" @ended="() => {
$refs.btnPlay.classList.remove('hidden');
$refs.btnPause.classList.add('hidden');
}"
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
<source src="{{ video_info.url }}">
Your browser does not support the video tag.
</video>
</div>
{% endfor %}
</div>
<!-- Shortcuts info -->
<div class="text-sm hidden md:block">
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to the next episode, <span class="font-mono">Arrow Up</span> to go to the previous episode.
</div>
<!-- Controllers -->
<div class="flex gap-1 text-3xl items-center">
<button x-ref="btnPlay" class="-rotate-90 hidden" class="-rotate-90" title="Play. Toggle with Space" @click="() => {
videos.forEach(video => video.play());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">🔽</button>
<button x-ref="btnPause" title="Pause. Toggle with Space" @click="() => {
videos.forEach(video => video.pause());
$refs.btnPlay.classList.toggle('hidden');
$refs.btnPause.classList.toggle('hidden');
}">⏸️</button>
<button title="Jump backward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime -= 5)))"></button>
<button title="Jump forward 5 seconds"
@click="() => (videos.forEach(video => (video.currentTime += 5)))"></button>
<button title="Rewind from start"
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
const sliderValue = $refs.slider.value;
$refs.btnPause.click();
videos.forEach(video => {
const time = (video.duration * sliderValue) / 100;
video.currentTime = time;
});
}" />
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
0:00
</div>
</div>
<!-- Graph -->
<div class="flex gap-2 mb-4 flex-wrap">
<div>
<div id="graph" @mouseleave="() => {
dygraph.setSelection(dygraphIndex, undefined, true, true);
dygraphTime = video.currentTime;
}">
</div>
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
Time: 0.00s
</p>
</div>
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columnNames[colIndex]}`"></p>
</div>
</th>
</template>
</tr>
</thead>
<tbody>
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 w-24 font-semibold px-1">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`Motor ${rowIndex}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
<td x-show="cell" class="border border-slate-700">
<div class="flex gap-x-2 w-24 justify-between px-2">
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
<span x-text="`${cell.value.toFixed(2)}`"
:style="`color: ${cell.color}`"></span>
</div>
</td>
</template>
</tr>
</template>
</tbody>
</table>
<div id="labels" class="hidden">
</div>
</div>
</div>
<script>
function createAlpineData() {
return {
// state
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
nColumns: {% if has_policy %}3{% else %}2{% endif %},
checked: [],
dygraphTime: 0.0,
dygraphIndex: 0,
videos: null,
video: null,
colors: null,
// alpine initialization
init() {
this.videos = document.querySelectorAll('video');
this.video = this.videos[0];
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
pixelsPerPoint: 0.01,
legend: 'always',
labelsDiv: document.getElementById('labels'),
labelsKMB: true,
strokeWidth: 1.5,
pointClickCallback: (event, point) => {
this.dygraphTime = point.xval;
this.updateTableValues(this.dygraphTime);
},
highlightCallback: (event, x, points, row, seriesName) => {
this.dygraphTime = x;
this.updateTableValues(this.dygraphTime);
},
drawCallback: (dygraph, is_initial) => {
if (is_initial) {
// dygraph initialization
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
this.colors = this.dygraph.getColors();
this.checked = Array(this.colors.length).fill(true);
const seriesNames = this.dygraph.getLabels().slice(1);
const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
let lightnessIdx = 0;
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
for (let i = 0; i < seriesNames.length; i += chunkSize) {
const lightness = LIGHTNESS[lightnessIdx];
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
const color = `hsl(${hue}, 100%, ${lightness}%)`;
colors.push(color);
}
lightnessIdx += 1;
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
this.updateTableValues();
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
let time = params.get("t");
if(time){
time = parseFloat(time);
this.videos.forEach(video => (video.currentTime = time));
}
}
},
});
},
//#region Table Data
// Turn dygraph's 1D data (at a given time t) into 2D data whose column names are defined in this.columnNames.
// The 2D data view is used to create the html table element.
get rows() {
if (!this.currentFrameData) {
return [];
}
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
return Array.from({
length: columnSize
}, (_, rowIndex) => {
const row = [
this.currentFrameData[rowIndex] || null,
this.currentFrameData[rowIndex + columnSize] || null,
];
if (this.nColumns === 3) {
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
}
return row;
});
},
isRowChecked(rowIndex) {
return this.rows[rowIndex].every(cell => cell && cell.checked);
},
isColumnChecked(colIndex) {
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
},
toggleRow(rowIndex) {
const newState = !this.isRowChecked(rowIndex);
this.rows[rowIndex].forEach(cell => {
if (cell) cell.checked = newState;
});
this.updateTableValues();
},
toggleColumn(colIndex) {
const newState = !this.isColumnChecked(colIndex);
this.rows.forEach(row => {
if (row[colIndex]) row[colIndex].checked = newState;
});
this.updateTableValues();
},
// given time t, update the values in the html table with "data[t]"
updateTableValues(time) {
if (!this.colors) {
return;
}
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
if (isNaN(pc)) pc = 0;
const index = Math.floor(pc * this.dygraph.numRows() / 100);
// slice(1) to remove the timestamp point that we do not need
const labels = this.dygraph.getLabels().slice(1);
const values = this.dygraph.rawData_[index].slice(1);
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
this.colors.length).fill(true);
this.currentFrameData = labels.map((label, idx) => ({
label,
value: values[idx],
color: this.colors[idx],
checked: checkedNew[idx],
}));
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
if (shouldUpdateVisibility) {
this.checked = checkedNew;
this.dygraph.setVisibility(this.checked);
}
},
//#endregion
updateTimeQuery(time) {
let url = new URL(window.location.href);
let params = new URLSearchParams(url.search);
params.set("t", time);
url.search = params.toString();
window.history.replaceState({}, '', url.toString());
},
formatTime(time) {
var hours = Math.floor(time / 3600);
var minutes = Math.floor((time % 3600) / 60);
var seconds = Math.floor(time % 60);
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
10 ?
'0' + seconds : seconds);
}
};
}
</script>
</body>
</html>

poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"]
html5lib = ["html5lib"]
lxml = ["lxml"]
[[package]]
name = "blinker"
version = "1.8.2"
description = "Fast, simple object-to-object and broadcast signaling"
optional = false
python-versions = ">=3.8"
files = [
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
]
[[package]]
name = "certifi"
version = "2024.7.4"
@@ -584,17 +595,6 @@ files = [
{file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"},
]
[[package]]
name = "decorator"
version = "4.4.2"
description = "Decorators for Humans"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
files = [
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
]
[[package]]
name = "deepdiff"
version = "7.0.1"
@@ -795,6 +795,7 @@ files = [
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
@@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "flask"
version = "3.0.3"
description = "A simple framework for building complex web applications."
optional = false
python-versions = ">=3.8"
files = [
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
]
[package.dependencies]
blinker = ">=1.6.2"
click = ">=8.1.3"
itsdangerous = ">=2.1.2"
Jinja2 = ">=3.1.2"
Werkzeug = ">=3.0.0"
[package.extras]
async = ["asgiref (>=3.2)"]
dotenv = ["python-dotenv"]
[[package]]
name = "frozenlist"
version = "1.4.1"
@@ -1550,6 +1573,17 @@ files = [
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
]
[[package]]
name = "itsdangerous"
version = "2.2.0"
description = "Safely pass data to untrusted environments and back."
optional = false
python-versions = ">=3.8"
files = [
{file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
{file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
]
[[package]]
name = "jinja2"
version = "3.1.4"
@@ -1741,9 +1775,13 @@ files = [
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
@@ -1901,30 +1939,6 @@ files = [
intel-openmp = "==2021.*"
tbb = "==2021.*"
[[package]]
name = "moviepy"
version = "1.0.3"
description = "Video editing with Python"
optional = false
python-versions = "*"
files = [
{file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
]
[package.dependencies]
decorator = ">=4.0.2,<5.0"
imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
proglog = "<=1.0.0"
requests = ">=2.8.1,<3.0"
tqdm = ">=4.11.2,<5.0"
[package.extras]
doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
[[package]]
name = "mpmath"
version = "1.3.0"
@@ -2696,20 +2710,6 @@ nodeenv = ">=0.11.1"
pyyaml = ">=5.1"
virtualenv = ">=20.10.0"
[[package]]
name = "proglog"
version = "0.1.10"
description = "Log and progress bar manager for console, notebooks, web..."
optional = false
python-versions = "*"
files = [
{file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
{file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
]
[package.dependencies]
tqdm = "*"
[[package]]
name = "protobuf"
version = "5.27.2"
@@ -3276,6 +3276,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3809,13 +3810,13 @@ test = ["pytest"]
[[package]]
name = "setuptools"
version = "71.0.1"
version = "71.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"},
{file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"},
{file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"},
{file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"},
]
[package.extras]
@@ -4215,6 +4216,23 @@ perf = ["orjson"]
sweeps = ["sweeps (>=0.2.0)"]
workspaces = ["wandb-workspaces"]
[[package]]
name = "werkzeug"
version = "3.0.3"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
files = [
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
]
[package.dependencies]
MarkupSafe = ">=2.1.1"
[package.extras]
watchdog = ["watchdog (>=2.3)"]
[[package]]
name = "xxhash"
version = "3.4.1"
@@ -4485,4 +4503,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"
content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096"

View File

@@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true}
datasets = ">=2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
moviepy = ">=1.0.3"
rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
scikit-image = {version = ">=0.23.2", optional = true}
flask = ">=3.0.3"
pandas = {version = ">=2.2.2", optional = true}
scikit-image = {version = ">=0.23.2", optional = true}
dynamixel-sdk = {version = ">=3.7.31", optional = true}
pynput = {version = ">=1.7.7", optional = true}
# TODO(rcadene, salibert): 71.0.1 has a bug
setuptools = {version = "!=71.0.1", optional = true}

View File

@@ -13,6 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Save the policy test artifacts.
Note: run this on the cluster.
Example of usage:
```bash
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
```
"""
import platform
import shutil
from pathlib import Path
@@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"]
loss.backward()
loss.mean().backward()
grad_stats = {}
for key, param in policy.named_parameters():
if param.requires_grad:
@@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
print(f" - Validate with: `git add {env_policy_dir}`")
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
from safetensors.torch import load_file
if (env_policy_dir / "output_dict.safetensors").exists():
prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
print(f"Previous loss={prev_loss}")
print(f"New loss={output_dict['loss'].mean()}")
print()
if env_policy_dir.exists():
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
@@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
if platform.machine() != "x86_64":
raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
env_policies = [
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
# (
# "pusht",
# "diffusion",
# [
# "policy.n_action_steps=8",
# "policy.num_inference_steps=10",
# "policy.down_dims=[128, 256, 512]",
# ],
# "",
# ),
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
(
"pusht",
"diffusion",
[
"policy.n_action_steps=8",
"policy.num_inference_steps=10",
"policy.down_dims=[128, 256, 512]",
],
"",
),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
]
if len(env_policies) == 0:
raise RuntimeError("No policies were provided!")
for env, policy, extra_overrides, file_name_extra in env_policies:
print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
)
print()

View File

@@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we run select_actions and get the appropriate output.
env = make_env(cfg, n_envs=2)
batch_size = 2
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=2,
batch_size=batch_size,
shuffle=True,
pin_memory=DEVICE != "cpu",
drop_last=True,
@@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
out = policy.forward(batch)
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
assert all(
torch.equal(batch[k], batch_[k]) for k in batch
), "Batch values are not the same after a forward pass."
# Test loss can be visualized using visualize_dataset_html.py
for key in out:
if "loss" in key:
assert (
out[key].ndim == 1 and out[key].shape[0] == batch_size
), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
# reset the policy and environment
policy.reset()
observation, _ = env.reset(seed=cfg.seed)
@@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str):
[
("xarm", "tdmpc"),
("pusht", "diffusion"),
("pusht", "vqbet"),
("aloha", "act"),
],
)
@@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
def test_save_and_load_pretrained(policy_name: str):
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy: Policy = policy_cls()
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
policy_ = policy_cls.from_pretrained(save_dir)
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
@@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
"",
),
("pusht", "vqbet", "[]", ""),
("aloha", "act", ["policy.n_action_steps=10"], ""),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
@@ -461,7 +471,3 @@ def test_act_temporal_ensembler():
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
if __name__ == "__main__":
test_act_temporal_ensembler()

View File

@@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
def test_visualize_dataset_root(tmpdir, repo_id, root):
rrd_path = visualize_dataset(
repo_id,
root=root,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
)
assert rrd_path.exists()

View File

@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
from tests.utils import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
def test_visualize_dataset_html(tmpdir, repo_id):
tmpdir = Path(tmpdir)
visualize_dataset_html(
repo_id,
episodes=[0],
output_dir=tmpdir,
serve=False,
)
assert (tmpdir / "static" / "episode_0.csv").exists()
@pytest.mark.parametrize(
"repo_id, policy_method",
[
("lerobot/pusht", "select_action"),
("lerobot/pusht", "forward"),
],
)
def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
tmpdir = Path(tmpdir)
# Create a policy
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
# Save a checkpoint
logger = Logger(cfg, tmpdir)
logger.save_model(tmpdir, policy)
visualize_dataset_html(
repo_id,
episodes=[0],
output_dir=tmpdir,
serve=False,
pretrained_policy_name_or_path=tmpdir,
policy_method=policy_method,
)
assert (tmpdir / "static" / "episode_0.csv").exists()
assert (tmpdir / "episode_0.safetensors").exists()