Compare commits
4 Commits
smolvla_do
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ddbbd8e80 | ||
|
|
1da5caaf4b | ||
|
|
ca7f207d74 | ||
|
|
6b9dcadbf7 |
@@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: Path | None = DATA_DIR,
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
@@ -53,22 +53,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = root
|
||||
if self.root is None and DATA_DIR is not None:
|
||||
self.root = DATA_DIR
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
|
||||
@property
|
||||
|
||||
@@ -233,9 +233,6 @@ class Logger:
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
|
||||
@@ -134,25 +134,26 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
bsize = actions_hat.shape[0]
|
||||
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
|
||||
|
||||
out_dict = {}
|
||||
out_dict["l1_loss"] = l1_loss
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
out_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
|
||||
return out_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
||||
@@ -341,7 +341,11 @@ class DiffusionModel(nn.Module):
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
# Compute average per item in the batch
|
||||
bsize = loss.shape[0]
|
||||
loss = loss.reshape(bsize, -1).mean(1)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
|
||||
@@ -396,51 +396,39 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||
# predicted from the (target model's) observation encoder.
|
||||
consistency_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0)
|
||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||
# rewards.
|
||||
reward_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).sum(0)
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0)
|
||||
# Compute state value loss as in eqn 3 of FOWM.
|
||||
diff = v_targets - v_preds
|
||||
# Expectile loss penalizes:
|
||||
@@ -450,16 +438,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
|
||||
) * (diff**2)
|
||||
v_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).sum(0)
|
||||
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
@@ -492,7 +476,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
).sum(0)
|
||||
|
||||
loss = (
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
@@ -504,13 +488,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
info.update(
|
||||
{
|
||||
"consistency_loss": consistency_loss.item(),
|
||||
"reward_loss": reward_loss.item(),
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"consistency_loss": consistency_loss,
|
||||
"reward_loss": reward_loss,
|
||||
"Q_value_loss": q_value_loss,
|
||||
"V_value_loss": v_value_loss,
|
||||
"pi_loss": pi_loss,
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
"sum_loss": loss * self.config.horizon,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -47,3 +53,26 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
return pretrained_policy_path
|
||||
|
||||
@@ -24,7 +24,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
@@ -50,7 +50,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
@@ -48,7 +48,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
@@ -56,9 +56,6 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
@@ -68,7 +65,7 @@ from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
@@ -501,29 +498,6 @@ def main(
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
return pretrained_policy_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
|
||||
@@ -120,8 +120,7 @@ def update_policy(
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
loss = output_dict["loss"]
|
||||
loss = output_dict["loss"].mean()
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||
@@ -150,14 +149,12 @@ def update_policy(
|
||||
policy.update()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"update_s": time.perf_counter() - start_time,
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
**{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
|
||||
**{k: v for k, v in output_dict.items() if "loss" not in k},
|
||||
}
|
||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
|
||||
@@ -108,8 +108,8 @@ def visualize_dataset(
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
output_dir: Path | None = None,
|
||||
root: Path | None = None,
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
assert (
|
||||
@@ -209,6 +209,18 @@ def main():
|
||||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
@@ -254,17 +266,6 @@ def main():
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
help="Directory path to write a .rrd file when `--save 1` is set.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Root directory for a dataset stored on a local machine.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset(**vars(args))
|
||||
|
||||
467
lerobot/scripts/visualize_dataset_html.py
Normal file
467
lerobot/scripts/visualize_dataset_html.py
Normal file
@@ -0,0 +1,467 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Example of usage:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```bash
|
||||
local$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```bash
|
||||
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Select episodes to visualize:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
|
||||
- Run inference of a policy on the dataset and visualize the results:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
-p lerobot/diffusion_pusht \
|
||||
--policy-overrides device=cpu
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.utils import get_pretrained_policy_path
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def run_server(
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int],
|
||||
host: str,
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
has_policy: bool = False,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
# home page redirects to the first episode page
|
||||
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
|
||||
first_episode_id = episodes[0]
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=first_episode_id,
|
||||
)
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_samples,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = get_episode_video_paths(dataset, episode_id)
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
"visualize_dataset_template.html",
|
||||
episode_id=episode_id,
|
||||
episodes=episodes,
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=has_policy,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
|
||||
|
||||
def get_ep_csv_fname(episode_id: int):
|
||||
ep_csv_fname = f"episode_{episode_id}.csv"
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
has_inference = inference_results is not None
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = len(dataset.hf_dataset["observation.state"][0])
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
if has_inference:
|
||||
if "action" in inference_results:
|
||||
dim_pred_action = inference_results["action"].shape[1]
|
||||
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
header += [key]
|
||||
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
columns += ["observation.state"]
|
||||
if has_action:
|
||||
columns += ["action"]
|
||||
|
||||
rows = []
|
||||
data = dataset.hf_dataset.select_columns(columns)
|
||||
for i in range(from_idx, to_idx):
|
||||
row = [data[i]["timestamp"].item()]
|
||||
if has_state:
|
||||
row += data[i]["observation.state"].tolist()
|
||||
if has_action:
|
||||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
if has_inference:
|
||||
num_frames = len(rows)
|
||||
if "action" in inference_results:
|
||||
assert num_frames == inference_results["action"].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += inference_results["action"][i].tolist()
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
assert num_frames == inference_results[key].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += [inference_results[key][i].item()]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
for row in rows:
|
||||
row_str = [str(col) for col in row]
|
||||
f.write(",".join(row_str) + "\n")
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.video_frame_keys
|
||||
]
|
||||
|
||||
|
||||
def run_inference(
|
||||
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
|
||||
):
|
||||
if policy_method not in ["select_action", "forward"]:
|
||||
raise ValueError(
|
||||
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
|
||||
batch_size=1 if policy_method == "select_action" else batch_size,
|
||||
sampler=episode_sampler,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
warned_ndim_eq_0 = False
|
||||
warned_ndim_gt_2 = False
|
||||
|
||||
logging.info("Running inference")
|
||||
inference_results = {}
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
with torch.inference_mode():
|
||||
if policy_method == "select_action":
|
||||
gt_action = batch.pop("action")
|
||||
output_dict = {"action": policy.select_action(batch)}
|
||||
batch["action"] = gt_action
|
||||
elif policy_method == "forward":
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): Save and display all predicted actions at a given timestamp
|
||||
# Save predicted action for the next timestamp only
|
||||
output_dict["action"] = output_dict["action"][:, 0, :]
|
||||
|
||||
for key in output_dict:
|
||||
if output_dict[key].ndim == 0:
|
||||
if not warned_ndim_eq_0:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_eq_0 = True
|
||||
continue
|
||||
|
||||
if output_dict[key].ndim > 2:
|
||||
if not warned_ndim_gt_2:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_gt_2 = True
|
||||
continue
|
||||
|
||||
if key not in inference_results:
|
||||
inference_results[key] = []
|
||||
inference_results[key].append(output_dict[key].to("cpu"))
|
||||
|
||||
for key in inference_results:
|
||||
inference_results[key] = torch.cat(inference_results[key])
|
||||
|
||||
return inference_results
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
policy_method: str = "select_action",
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: list[str] | None = None,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
has_policy = pretrained_policy_name_or_path is not None
|
||||
|
||||
if has_policy:
|
||||
logging.info("Loading policy")
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
dataset = make_dataset(hydra_cfg)
|
||||
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
if policy_method == "select_action":
|
||||
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||
# an environment.
|
||||
dataset.delta_timestamps = None
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
||||
if has_policy:
|
||||
ckpt_str = pretrained_policy_path.parts[-2]
|
||||
exp_name = pretrained_policy_path.parts[-4]
|
||||
output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
static_dir = output_dir / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(range(dataset.num_episodes))
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
inference_results = None
|
||||
if has_policy:
|
||||
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
|
||||
if inference_results_path.exists():
|
||||
inference_results = load_file(inference_results_path)
|
||||
else:
|
||||
inference_results = run_inference(
|
||||
dataset,
|
||||
episode_index,
|
||||
policy,
|
||||
policy_method,
|
||||
num_workers=hydra_cfg.training.num_workers,
|
||||
batch_size=hydra_cfg.training.batch_size,
|
||||
device=hydra_cfg.device,
|
||||
)
|
||||
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(inference_results, inference_results_path)
|
||||
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serve",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch web server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Web host used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9090,
|
||||
help="Web port used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--policy-method",
|
||||
type=str,
|
||||
default="select_action",
|
||||
choices=["select_action", "forward"],
|
||||
help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy-overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
360
lerobot/templates/visualize_dataset_template.html
Normal file
360
lerobot/templates/visualize_dataset_template.html
Normal file
@@ -0,0 +1,360 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<!-- # TODO(rcadene, mishig25): store the js files locally -->
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/alpinejs/3.13.5/cdn.min.js" defer></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/dygraphs@2.2.1/dist/dygraph.min.js" type="text/javascript"></script>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<title>{{ dataset_info.repo_id }} episode {{ episode_id }}</title>
|
||||
</head>
|
||||
|
||||
<!-- Use [Alpin.js](https://alpinejs.dev), a lightweight and easy to learn JS framework -->
|
||||
<!-- Use [tailwindcss](https://tailwindcss.com/), CSS classes for styling html -->
|
||||
<!-- Use [dygraphs](https://dygraphs.com/), a lightweight JS charting library -->
|
||||
<body class="flex h-screen max-h-screen bg-slate-950 text-gray-200" x-data="createAlpineData()" @keydown.window="(e) => {
|
||||
// Use the space bar to play and pause, instead of default action (e.g. scrolling)
|
||||
const { keyCode, key } = e;
|
||||
if (keyCode === 32 || key === ' ') {
|
||||
e.preventDefault();
|
||||
$refs.btnPause.classList.contains('hidden') ? $refs.btnPlay.click() : $refs.btnPause.click();
|
||||
}else if (key === 'ArrowDown' || key === 'ArrowUp'){
|
||||
const nextEpisodeId = key === 'ArrowDown' ? {{ episode_id }} + 1 : {{ episode_id }} - 1;
|
||||
const lowestEpisodeId = {{ episodes }}.at(0);
|
||||
const highestEpisodeId = {{ episodes }}.at(-1);
|
||||
if(nextEpisodeId >= lowestEpisodeId && nextEpisodeId <= highestEpisodeId){
|
||||
window.location.href = `./episode_${nextEpisodeId}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="w-60 bg-slate-900 p-5 break-words max-h-screen overflow-y-auto">
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
</li>
|
||||
<li>
|
||||
Frames per second: {{ dataset_info.fps }}
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<p>Episodes:</p>
|
||||
<ul class="ml-2">
|
||||
{% for episode in episodes %}
|
||||
<li class="font-mono text-sm mt-0.5">
|
||||
<a href="episode_{{ episode }}" class="underline {% if episode_id == episode %}font-bold -ml-1{% endif %}">
|
||||
Episode {{ episode }}
|
||||
</a>
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
|
||||
</div>
|
||||
|
||||
<!-- Toggle sidebar button -->
|
||||
<button class="flex items-center opacity-50 hover:opacity-100 mx-1"
|
||||
@click="() => ($refs.sidebar.classList.toggle('hidden'))" title="Toggle sidebar">
|
||||
<div class="bg-slate-500 w-2 h-10 rounded-full"></div>
|
||||
</button>
|
||||
|
||||
<!-- Content -->
|
||||
<div class="flex-1 max-h-screen flex flex-col gap-4 overflow-y-auto">
|
||||
<h1 class="text-xl font-bold mt-4 font-mono">
|
||||
Episode {{ episode_id }}
|
||||
</h1>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{% for video_info in videos_info %}
|
||||
<div class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video autoplay muted loop type="video/mp4" class="min-w-64" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
const pc = (100 / video.duration) * time;
|
||||
$refs.slider.value = pc;
|
||||
dygraphTime = time;
|
||||
dygraphIndex = Math.floor(pc * dygraph.numRows() / 100);
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
|
||||
$refs.timer.textContent = formatTime(time) + ' / ' + formatTime(video.duration);
|
||||
|
||||
updateTimeQuery(time.toFixed(2));
|
||||
}
|
||||
}" @ended="() => {
|
||||
$refs.btnPlay.classList.remove('hidden');
|
||||
$refs.btnPause.classList.add('hidden');
|
||||
}"
|
||||
@loadedmetadata="() => ($refs.timer.textContent = formatTime(0) + ' / ' + formatTime(video.duration))">
|
||||
<source src="{{ video_info.url }}">
|
||||
Your browser does not support the video tag.
|
||||
</video>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
|
||||
<!-- Shortcuts info -->
|
||||
<div class="text-sm hidden md:block">
|
||||
Hotkeys: <span class="font-mono">Space</span> to pause/unpause, <span class="font-mono">Arrow Down</span> to go to next episode, <span class="font-mono">Arrow Up</span> to go to previous episode.
|
||||
</div>
|
||||
|
||||
<!-- Controllers -->
|
||||
<div class="flex gap-1 text-3xl items-center">
|
||||
<button x-ref="btnPlay" class="-rotate-90 hidden" class="-rotate-90" title="Play. Toggle with Space" @click="() => {
|
||||
videos.forEach(video => video.play());
|
||||
$refs.btnPlay.classList.toggle('hidden');
|
||||
$refs.btnPause.classList.toggle('hidden');
|
||||
}">🔽</button>
|
||||
<button x-ref="btnPause" title="Pause. Toggle with Space" @click="() => {
|
||||
videos.forEach(video => video.pause());
|
||||
$refs.btnPlay.classList.toggle('hidden');
|
||||
$refs.btnPause.classList.toggle('hidden');
|
||||
}">⏸️</button>
|
||||
<button title="Jump backward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime -= 5)))">⏪</button>
|
||||
<button title="Jump forward 5 seconds"
|
||||
@click="() => (videos.forEach(video => (video.currentTime += 5)))">⏩</button>
|
||||
<button title="Rewind from start"
|
||||
@click="() => (videos.forEach(video => (video.currentTime = 0.0)))">↩️</button>
|
||||
<input x-ref="slider" max="100" min="0" step="1" type="range" value="0" class="w-80 mx-2" @input="() => {
|
||||
const sliderValue = $refs.slider.value;
|
||||
$refs.btnPause.click();
|
||||
videos.forEach(video => {
|
||||
const time = (video.duration * sliderValue) / 100;
|
||||
video.currentTime = time;
|
||||
});
|
||||
}" />
|
||||
<div x-ref="timer" class="font-mono text-sm border border-slate-500 rounded-lg px-1 py-0.5 shrink-0">0:00 /
|
||||
0:00
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Graph -->
|
||||
<div class="flex gap-2 mb-4 flex-wrap">
|
||||
<div>
|
||||
<div id="graph" @mouseleave="() => {
|
||||
dygraph.setSelection(dygraphIndex, undefined, true, true);
|
||||
dygraphTime = video.currentTime;
|
||||
}">
|
||||
</div>
|
||||
<p x-ref="graphTimer" class="font-mono ml-14 mt-4"
|
||||
x-init="$watch('dygraphTime', value => ($refs.graphTimer.innerText = `Time: ${dygraphTime.toFixed(2)}s`))">
|
||||
Time: 0.00s
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<table class="text-sm border-collapse border border-slate-700" x-show="currentFrameData">
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
|
||||
<th class="border border-slate-700">
|
||||
<div class="flex gap-x-2 justify-between px-2">
|
||||
<input type="checkbox" :checked="isColumnChecked(colIndex)"
|
||||
@change="toggleColumn(colIndex)">
|
||||
<p x-text="`${columnNames[colIndex]}`"></p>
|
||||
</div>
|
||||
</th>
|
||||
</template>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<template x-for="(row, rowIndex) in rows">
|
||||
<tr class="odd:bg-gray-800 even:bg-gray-900">
|
||||
<td class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 font-semibold px-1">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`Motor ${rowIndex}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
<td x-show="cell" class="border border-slate-700">
|
||||
<div class="flex gap-x-2 w-24 justify-between px-2">
|
||||
<input type="checkbox" x-model="cell.checked" @change="updateTableValues()">
|
||||
<span x-text="`${cell.value.toFixed(2)}`"
|
||||
:style="`color: ${cell.color}`"></span>
|
||||
</div>
|
||||
</td>
|
||||
</template>
|
||||
</tr>
|
||||
</template>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<div id="labels" class="hidden">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function createAlpineData() {
|
||||
return {
|
||||
// state
|
||||
dygraph: null,
|
||||
currentFrameData: null,
|
||||
columnNames: ["state", "action", "pred action"],
|
||||
nColumns: {% if has_policy %}3{% else %}2{% endif %},
|
||||
checked: [],
|
||||
dygraphTime: 0.0,
|
||||
dygraphIndex: 0,
|
||||
videos: null,
|
||||
video: null,
|
||||
colors: null,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
|
||||
pixelsPerPoint: 0.01,
|
||||
legend: 'always',
|
||||
labelsDiv: document.getElementById('labels'),
|
||||
labelsKMB: true,
|
||||
strokeWidth: 1.5,
|
||||
pointClickCallback: (event, point) => {
|
||||
this.dygraphTime = point.xval;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
highlightCallback: (event, x, points, row, seriesName) => {
|
||||
this.dygraphTime = x;
|
||||
this.updateTableValues(this.dygraphTime);
|
||||
},
|
||||
drawCallback: (dygraph, is_initial) => {
|
||||
if (is_initial) {
|
||||
// dygraph initialization
|
||||
this.dygraph.setSelection(this.dygraphIndex, undefined, true, true);
|
||||
this.colors = this.dygraph.getColors();
|
||||
this.checked = Array(this.colors.length).fill(true);
|
||||
|
||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||
const colors = [];
|
||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
let lightnessIdx = 0;
|
||||
const chunkSize = Math.ceil(seriesNames.length / this.nColumns);
|
||||
for (let i = 0; i < seriesNames.length; i += chunkSize) {
|
||||
const lightness = LIGHTNESS[lightnessIdx];
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/chunkSize)) {
|
||||
const color = `hsl(${hue}, 100%, ${lightness}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
lightnessIdx += 1;
|
||||
}
|
||||
this.dygraph.updateOptions({ colors });
|
||||
this.colors = colors;
|
||||
|
||||
this.updateTableValues();
|
||||
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
let time = params.get("t");
|
||||
if(time){
|
||||
time = parseFloat(time);
|
||||
this.videos.forEach(video => (video.currentTime = time));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
},
|
||||
|
||||
//#region Table Data
|
||||
|
||||
// turn dygraph's 1D data (at a given time t) to 2D data that whose columns names are defined in this.columnNames.
|
||||
// 2d data view is used to create html table element.
|
||||
get rows() {
|
||||
if (!this.currentFrameData) {
|
||||
return [];
|
||||
}
|
||||
const columnSize = Math.ceil(this.currentFrameData.length / this.nColumns);
|
||||
return Array.from({
|
||||
length: columnSize
|
||||
}, (_, rowIndex) => {
|
||||
const row = [
|
||||
this.currentFrameData[rowIndex] || null,
|
||||
this.currentFrameData[rowIndex + columnSize] || null,
|
||||
];
|
||||
if (this.nColumns === 3) {
|
||||
row.push(this.currentFrameData[rowIndex + 2 * columnSize] || null)
|
||||
}
|
||||
return row;
|
||||
});
|
||||
},
|
||||
isRowChecked(rowIndex) {
|
||||
return this.rows[rowIndex].every(cell => cell && cell.checked);
|
||||
},
|
||||
isColumnChecked(colIndex) {
|
||||
return this.rows.every(row => row[colIndex] && row[colIndex].checked);
|
||||
},
|
||||
toggleRow(rowIndex) {
|
||||
const newState = !this.isRowChecked(rowIndex);
|
||||
this.rows[rowIndex].forEach(cell => {
|
||||
if (cell) cell.checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
toggleColumn(colIndex) {
|
||||
const newState = !this.isColumnChecked(colIndex);
|
||||
this.rows.forEach(row => {
|
||||
if (row[colIndex]) row[colIndex].checked = newState;
|
||||
});
|
||||
this.updateTableValues();
|
||||
},
|
||||
|
||||
// given time t, update the values in the html table with "data[t]"
|
||||
updateTableValues(time) {
|
||||
if (!this.colors) {
|
||||
return;
|
||||
}
|
||||
let pc = (100 / this.video.duration) * (time === undefined ? this.video.currentTime : time);
|
||||
if (isNaN(pc)) pc = 0;
|
||||
const index = Math.floor(pc * this.dygraph.numRows() / 100);
|
||||
// slice(1) to remove the timestamp point that we do not need
|
||||
const labels = this.dygraph.getLabels().slice(1);
|
||||
const values = this.dygraph.rawData_[index].slice(1);
|
||||
const checkedNew = this.currentFrameData ? this.currentFrameData.map(cell => cell.checked) : Array(
|
||||
this.colors.length).fill(true);
|
||||
this.currentFrameData = labels.map((label, idx) => ({
|
||||
label,
|
||||
value: values[idx],
|
||||
color: this.colors[idx],
|
||||
checked: checkedNew[idx],
|
||||
}));
|
||||
const shouldUpdateVisibility = !this.checked.every((value, index) => value === checkedNew[index]);
|
||||
if (shouldUpdateVisibility) {
|
||||
this.checked = checkedNew;
|
||||
this.dygraph.setVisibility(this.checked);
|
||||
}
|
||||
},
|
||||
|
||||
//#endregion
|
||||
|
||||
updateTimeQuery(time) {
|
||||
let url = new URL(window.location.href);
|
||||
let params = new URLSearchParams(url.search);
|
||||
params.set("t", time);
|
||||
url.search = params.toString();
|
||||
window.history.replaceState({}, '', url.toString());
|
||||
},
|
||||
|
||||
|
||||
formatTime(time) {
|
||||
var hours = Math.floor(time / 3600);
|
||||
var minutes = Math.floor((time % 3600) / 60);
|
||||
var seconds = Math.floor(time % 60);
|
||||
return (hours > 0 ? hours + ':' : '') + (minutes < 10 ? '0' + minutes : minutes) + ':' + (seconds <
|
||||
10 ?
|
||||
'0' + seconds : seconds);
|
||||
}
|
||||
};
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
126
poetry.lock
generated
126
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@@ -192,6 +192,17 @@ charset-normalizer = ["charset-normalizer"]
|
||||
html5lib = ["html5lib"]
|
||||
lxml = ["lxml"]
|
||||
|
||||
[[package]]
|
||||
name = "blinker"
|
||||
version = "1.8.2"
|
||||
description = "Fast, simple object-to-object and broadcast signaling"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"},
|
||||
{file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2024.7.4"
|
||||
@@ -584,17 +595,6 @@ files = [
|
||||
{file = "debugpy-1.8.2.zip", hash = "sha256:95378ed08ed2089221896b9b3a8d021e642c24edc8fef20e5d4342ca8be65c00"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "4.4.2"
|
||||
description = "Decorators for Humans"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*"
|
||||
files = [
|
||||
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
|
||||
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deepdiff"
|
||||
version = "7.0.1"
|
||||
@@ -795,6 +795,7 @@ files = [
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
|
||||
@@ -892,6 +893,28 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
|
||||
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
|
||||
typing = ["typing-extensions (>=4.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "flask"
|
||||
version = "3.0.3"
|
||||
description = "A simple framework for building complex web applications."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"},
|
||||
{file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
blinker = ">=1.6.2"
|
||||
click = ">=8.1.3"
|
||||
itsdangerous = ">=2.1.2"
|
||||
Jinja2 = ">=3.1.2"
|
||||
Werkzeug = ">=3.0.0"
|
||||
|
||||
[package.extras]
|
||||
async = ["asgiref (>=3.2)"]
|
||||
dotenv = ["python-dotenv"]
|
||||
|
||||
[[package]]
|
||||
name = "frozenlist"
|
||||
version = "1.4.1"
|
||||
@@ -1550,6 +1573,17 @@ files = [
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.2.0"
|
||||
description = "Safely pass data to untrusted environments and back."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
|
||||
{file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.4"
|
||||
@@ -1741,9 +1775,13 @@ files = [
|
||||
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
|
||||
@@ -1901,30 +1939,6 @@ files = [
|
||||
intel-openmp = "==2021.*"
|
||||
tbb = "==2021.*"
|
||||
|
||||
[[package]]
|
||||
name = "moviepy"
|
||||
version = "1.0.3"
|
||||
description = "Video editing with Python"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "moviepy-1.0.3.tar.gz", hash = "sha256:2884e35d1788077db3ff89e763c5ba7bfddbd7ae9108c9bc809e7ba58fa433f5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
decorator = ">=4.0.2,<5.0"
|
||||
imageio = {version = ">=2.5,<3.0", markers = "python_version >= \"3.4\""}
|
||||
imageio_ffmpeg = {version = ">=0.2.0", markers = "python_version >= \"3.4\""}
|
||||
numpy = {version = ">=1.17.3", markers = "python_version > \"2.7\""}
|
||||
proglog = "<=1.0.0"
|
||||
requests = ">=2.8.1,<3.0"
|
||||
tqdm = ">=4.11.2,<5.0"
|
||||
|
||||
[package.extras]
|
||||
doc = ["Sphinx (>=1.5.2,<2.0)", "numpydoc (>=0.6.0,<1.0)", "pygame (>=1.9.3,<2.0)", "sphinx_rtd_theme (>=0.1.10b0,<1.0)"]
|
||||
optional = ["matplotlib (>=2.0.0,<3.0)", "opencv-python (>=3.0,<4.0)", "scikit-image (>=0.13.0,<1.0)", "scikit-learn", "scipy (>=0.19.0,<1.5)", "youtube_dl"]
|
||||
test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "pytest-cov (>=2.5.1,<3.0)", "requests (>=2.8.1,<3.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
@@ -2696,20 +2710,6 @@ nodeenv = ">=0.11.1"
|
||||
pyyaml = ">=5.1"
|
||||
virtualenv = ">=20.10.0"
|
||||
|
||||
[[package]]
|
||||
name = "proglog"
|
||||
version = "0.1.10"
|
||||
description = "Log and progress bar manager for console, notebooks, web..."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "proglog-0.1.10-py3-none-any.whl", hash = "sha256:19d5da037e8c813da480b741e3fa71fb1ac0a5b02bf21c41577c7f327485ec50"},
|
||||
{file = "proglog-0.1.10.tar.gz", hash = "sha256:658c28c9c82e4caeb2f25f488fff9ceace22f8d69b15d0c1c86d64275e4ddab4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "5.27.2"
|
||||
@@ -3276,6 +3276,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@@ -3809,13 +3810,13 @@ test = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "71.0.1"
|
||||
version = "71.0.0"
|
||||
description = "Easily download, build, install, upgrade, and uninstall Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "setuptools-71.0.1-py3-none-any.whl", hash = "sha256:1eb8ef012efae7f6acbc53ec0abde4bc6746c43087fd215ee09e1df48998711f"},
|
||||
{file = "setuptools-71.0.1.tar.gz", hash = "sha256:c51d7fd29843aa18dad362d4b4ecd917022131425438251f4e3d766c964dd1ad"},
|
||||
{file = "setuptools-71.0.0-py3-none-any.whl", hash = "sha256:f06fbe978a91819d250a30e0dc4ca79df713d909e24438a42d0ec300fc52247f"},
|
||||
{file = "setuptools-71.0.0.tar.gz", hash = "sha256:98da3b8aca443b9848a209ae4165e2edede62633219afa493a58fbba57f72e2e"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -4215,6 +4216,23 @@ perf = ["orjson"]
|
||||
sweeps = ["sweeps (>=0.2.0)"]
|
||||
workspaces = ["wandb-workspaces"]
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.0.3"
|
||||
description = "The comprehensive WSGI web application library."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
|
||||
{file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
MarkupSafe = ">=2.1.1"
|
||||
|
||||
[package.extras]
|
||||
watchdog = ["watchdog (>=2.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "xxhash"
|
||||
version = "3.4.1"
|
||||
@@ -4485,4 +4503,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"
|
||||
content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096"
|
||||
|
||||
@@ -57,13 +57,15 @@ pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
moviepy = ">=1.0.3"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
deepdiff = ">=7.0.1"
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
flask = ">=3.0.3"
|
||||
pandas = {version = ">=2.2.2", optional = true}
|
||||
scikit-image = {version = ">=0.23.2", optional = true}
|
||||
dynamixel-sdk = {version = ">=3.7.31", optional = true}
|
||||
pynput = {version = ">=1.7.7", optional = true}
|
||||
# TODO(rcadene, salibert): 71.0.1 has a bug
|
||||
setuptools = {version = "!=71.0.1", optional = true}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Save the policy tests artifacts.
|
||||
|
||||
Note: Run on the cluster
|
||||
|
||||
Example of usage:
|
||||
```bash
|
||||
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
|
||||
```
|
||||
"""
|
||||
|
||||
import platform
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -54,7 +66,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
loss = output_dict["loss"]
|
||||
|
||||
loss.backward()
|
||||
loss.mean().backward()
|
||||
grad_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
if param.requires_grad:
|
||||
@@ -96,10 +108,21 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||
print(f" - Validate with: `git add {env_policy_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if (env_policy_dir / "output_dict.safetensors").exists():
|
||||
prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
|
||||
print(f"Previous loss={prev_loss}")
|
||||
print(f"New loss={output_dict['loss'].mean()}")
|
||||
print()
|
||||
|
||||
if env_policy_dir.exists():
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
@@ -107,27 +130,32 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if platform.machine() != "x86_64":
|
||||
raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
|
||||
|
||||
env_policies = [
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
# (
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# [
|
||||
# "policy.n_action_steps=8",
|
||||
# "policy.num_inference_steps=10",
|
||||
# "policy.down_dims=[128, 256, 512]",
|
||||
# ],
|
||||
# "",
|
||||
# ),
|
||||
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
[
|
||||
"policy.n_action_steps=8",
|
||||
"policy.num_inference_steps=10",
|
||||
"policy.down_dims=[128, 256, 512]",
|
||||
],
|
||||
"",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
]
|
||||
if len(env_policies) == 0:
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
)
|
||||
print()
|
||||
|
||||
@@ -147,10 +147,11 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, n_envs=2)
|
||||
|
||||
batch_size = 2
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=DEVICE != "cpu",
|
||||
drop_last=True,
|
||||
@@ -164,12 +165,19 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
out = policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
# Test loss can be visualized using visualize_dataset_html.py
|
||||
for key in out:
|
||||
if "loss" in key:
|
||||
assert (
|
||||
out[key].ndim == 1 and out[key].shape[0] == batch_size
|
||||
), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
@@ -234,6 +242,7 @@ def test_policy_defaults(policy_name: str):
|
||||
[
|
||||
("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("pusht", "vqbet"),
|
||||
("aloha", "act"),
|
||||
],
|
||||
)
|
||||
@@ -250,7 +259,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy: Policy = policy_cls()
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
@@ -365,6 +374,7 @@ def test_normalize(insert_temporal_dim):
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
"",
|
||||
),
|
||||
("pusht", "vqbet", "[]", ""),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
@@ -461,7 +471,3 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_act_temporal_ensembler()
|
||||
|
||||
@@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
||||
def test_visualize_local_dataset(tmpdir, repo_id, root):
|
||||
def test_visualize_dataset_root(tmpdir, repo_id, root):
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
root=root,
|
||||
)
|
||||
assert rrd_path.exists()
|
||||
|
||||
72
tests/test_visualize_dataset_html.py
Normal file
72
tests/test_visualize_dataset_html.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
tmpdir = Path(tmpdir)
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, policy_method",
|
||||
[
|
||||
("lerobot/pusht", "select_action"),
|
||||
("lerobot/pusht", "forward"),
|
||||
],
|
||||
)
|
||||
def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# Create a policy
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
|
||||
# Save a checkpoint
|
||||
logger = Logger(cfg, tmpdir)
|
||||
logger.save_model(tmpdir, policy)
|
||||
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
pretrained_policy_name_or_path=tmpdir,
|
||||
policy_method=policy_method,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
assert (tmpdir / "episode_0.safetensors").exists()
|
||||
Reference in New Issue
Block a user