Compare commits


4 Commits

Author | SHA1 | Message | Date
Remi Cadene | 5f32d75b58 | Save action, display next_action | 2025-02-18 22:12:59 +01:00
Remi Cadene | a66a792029 | Add save inference | 2025-02-18 10:18:49 +01:00
Remi Cadene | b6aedcd9a5 | Revert "Replace ArgumentParse by draccus in visualize_dataset_html" (This reverts commit d8746be37dcb84ebaa7896485150f0e5ad5dd3a3.) | 2025-02-18 10:18:49 +01:00
Remi Cadene | 121030cca7 | Replace ArgumentParse by draccus in visualize_dataset_html | 2025-02-18 10:18:49 +01:00
3 changed files with 229 additions and 14 deletions

View File

@@ -155,25 +155,34 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)

-        loss_dict = {"l1_loss": l1_loss.item()}
+        bsize, seqlen, num_motors = l1_loss.shape
+        output_dict = {
+            "l1_loss": l1_loss.mean().item(),
+            "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
+            "action": self.unnormalize_outputs({"action": actions_hat})["action"],
+        }

         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
+            mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            output_dict["kld_loss_per_item"] = mean_kld
+            mean_kld = mean_kld.mean()
+            output_dict["kld_loss"] = mean_kld.item()
             loss = l1_loss + mean_kld * self.config.kl_weight
+            output_dict["loss_per_item"] = (
+                output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
+            )
         else:
             loss = l1_loss

-        return loss, loss_dict
+        return loss, output_dict

 class ACTTemporalEnsembler:
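
For reference, the per-item quantities carried by the new `output_dict` can be reproduced in isolation. Below is a minimal sketch with random tensors; the shapes and the `kl_weight` value are made up for illustration, and only the formulas mirror the diff above.

import torch
import torch.nn.functional as F

bsize, seqlen, num_motors, latent_dim = 4, 100, 6, 32
kl_weight = 10.0  # made-up value standing in for config.kl_weight

action = torch.rand(bsize, seqlen, num_motors)        # stand-in for batch["action"]
actions_hat = torch.rand(bsize, seqlen, num_motors)   # stand-in for the model prediction
action_is_pad = torch.zeros(bsize, seqlen, dtype=torch.bool)
mu_hat = torch.randn(bsize, latent_dim)
log_sigma_x2_hat = torch.randn(bsize, latent_dim)

# Masked L1, kept unreduced so it can be averaged per batch element.
l1_loss = F.l1_loss(action, actions_hat, reduction="none")
l1_loss *= ~action_is_pad.unsqueeze(-1)
l1_loss_per_item = l1_loss.view(bsize, seqlen * num_motors).mean(dim=1)  # shape: (bsize,)

# KL(N(mu, sigma^2) || N(0, 1)) summed over the latent dimension, one value per batch element.
kld_loss_per_item = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1)

# Per-item total, i.e. what ends up in output_dict["loss_per_item"] when the VAE is enabled.
loss_per_item = l1_loss_per_item + kld_loss_per_item * kl_weight
print(loss_per_item.shape)  # torch.Size([4])

Each entry of `loss_per_item` is the per-frame quantity that the `save_inference` script below stores and the visualizer later plots.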

View File

@@ -0,0 +1,162 @@
import logging
import shutil
import tempfile
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

import torch
import tqdm

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
    auto_select_torch_device,
    init_logging,
    is_amp_available,
    is_torch_device_available,
)
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
class SaveInferenceConfig:
    dataset: DatasetConfig
    # Delete the output directory if it exists already.
    force_override: bool = False
    batch_size: int = 16
    num_workers: int = 4
    policy: PreTrainedConfig | None = None
    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool | None = None
    output_dir: str | Path | None = None

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

            # When no device or use_amp are given, use the one from training config.
            if self.device is None or self.use_amp is None:
                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
                if self.device is None:
                    self.device = train_cfg.device
                if self.use_amp is None:
                    self.use_amp = train_cfg.use_amp

        # Automatically switch to available device if necessary
        if not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False


@parser.wrap()
def save_inference(cfg: SaveInferenceConfig):
    init_logging()
    logging.info(pformat(asdict(cfg)))

    dataset = make_dataset(cfg)
    policy = make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)

    output_dir = cfg.output_dir
    if output_dir is None:
        # Create a temporary directory that will be automatically cleaned up
        output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
    elif Path(output_dir).exists() and cfg.force_override:
        shutil.rmtree(cfg.output_dir)
    output_dir = Path(output_dir)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    policy.train()

    episode_indices = []
    frame_indices = []
    feats = {}
    for batch in tqdm.tqdm(dataloader):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

        with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
            _, output_dict = policy.forward(batch)

        batch_size = batch["episode_index"].shape[0]
        episode_indices.append(batch["episode_index"])
        frame_indices.append(batch["frame_index"])

        for key, value in output_dict.items():
            if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
                print(f"Skipping {key}")
                continue
            if key not in feats:
                feats[key] = []
            feats[key].append(value)

    episode_indices = torch.cat(episode_indices).cpu()
    frame_indices = torch.cat(frame_indices).cpu()

    # TODO(rcadene): use collate?
    for key, value in feats.items():
        if isinstance(value[0], (float, int)):
            feats[key] = torch.tensor(value)
        elif isinstance(value[0], torch.Tensor):
            feats[key] = torch.cat(value, dim=0).cpu()
        elif isinstance(value[0], str):
            pass
        else:
            raise NotImplementedError(f"{key}: {value}")

    # Find unique episode indices
    unique_episodes = torch.unique(episode_indices)
    for episode in unique_episodes:
        ep_feats = {}
        for key in feats:
            ep_feats[key] = feats[key][episode_indices == episode]

        output_dir.mkdir(parents=True, exist_ok=True)
        torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")

    print(f"Features can be found in: {output_dir}")


if __name__ == "__main__":
    save_inference()
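
Each saved file is a plain dictionary of tensors, one entry per `output_dict` key that passed the batch-size check, indexed along the first dimension by the frames of that episode. A minimal sketch of inspecting one file; the directory path is illustrative (use the one printed by the script):

from pathlib import Path

import torch

# Illustrative path: use the directory printed by save_inference at the end of the run.
output_dir = Path("/tmp/lerobot_save_inference_example")

ep_feats = torch.load(output_dir / "output_features_episode_0.pth")
for key, value in ep_feats.items():
    print(key, tuple(value.shape))
# For the ACT policy change above, the keys would be the tensor-valued outputs, e.g.
# "l1_loss_per_item", "kld_loss_per_item", "loss_per_item" and "action", each with the
# episode's number of frames as the first dimension.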

View File

@@ -65,6 +65,7 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import requests
+import torch
 from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
@@ -80,6 +81,7 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
+    inference_dir: Path | None,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -139,7 +141,14 @@ def run_server(
         )

     @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+    def show_episode(
+        dataset_namespace,
+        dataset_name,
+        episode_id,
+        dataset=dataset,
+        episodes=episodes,
+        inference_dir=inference_dir,
+    ):
         repo_id = f"{dataset_namespace}/{dataset_name}"
         try:
             if dataset is None:
@@ -158,7 +167,7 @@ def run_server(
         if major_version < 2:
            return "Make sure to convert your LeRobotDataset to v2 & above."

-        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id, inference_dir)
         dataset_info = {
             "repo_id": f"{dataset_namespace}/{dataset_name}",
             "num_samples": dataset.num_frames
@@ -228,7 +237,9 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+def get_episode_data(
+    dataset: LeRobotDataset | IterableNamespace, episode_index, inference_dir: Path | None = None
+):
     """Get a csv str containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     columns = []
@@ -279,7 +290,33 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
             np.expand_dims(data["timestamp"], axis=1),
             *[np.vstack(data[col]) for col in selected_columns[1:]],
         )
-    ).tolist()
+    )
+
+    if inference_dir is not None:
+        feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
+        for key in feats:
+            if "loss_per_item" in key:
+                if feats[key].ndim != 1:
+                    raise ValueError()
+                header.append(key.replace("loss_per_item", "loss"))
+                rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+            elif key == "action":
+                if feats[key].ndim != 3:
+                    raise ValueError()
+                next_action = feats[key][:, 0, :]
+                num_motors = next_action.shape[1]
+                for i in range(num_motors):
+                    header.append(f"action_{i}")
+                rows = np.concatenate([rows, next_action], axis=1)
+            else:
+                raise NotImplementedError(key)
+
+    rows = rows.tolist()

     # Convert data to CSV string
     csv_buffer = StringIO()
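
For reference, a standalone sketch of the column handling added above, using toy shapes (3 frames, 2 motors, a chunk of 5 predicted actions) and made-up values; the key names mirror what `save_inference` writes, everything else is illustrative:

import numpy as np
import torch

num_frames, chunk_size, num_motors = 3, 5, 2
header = ["timestamp", "state_0", "state_1"]            # illustrative dygraph columns
rows = np.zeros((num_frames, len(header)))              # existing per-frame data
feats = {
    "loss_per_item": torch.rand(num_frames),            # one scalar per frame
    "action": torch.rand(num_frames, chunk_size, num_motors),
}

# Per-frame loss becomes a single extra column named "loss".
header.append("loss")
rows = np.concatenate([rows, feats["loss_per_item"][:, None]], axis=1)

# Only the first action of each predicted chunk is plotted, one column per motor.
next_action = feats["action"][:, 0, :]
header += [f"action_{i}" for i in range(num_motors)]
rows = np.concatenate([rows, next_action], axis=1)

print(header)      # ['timestamp', 'state_0', 'state_1', 'loss', 'action_0', 'action_1']
print(rows.shape)  # (3, 6)

Plotting only the first action of each chunk corresponds to the "display next_action" part of the commit message.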
@@ -332,6 +369,7 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    inference_dir: Path | None = None,
 ) -> Path | None:
     init_logging()
@@ -372,7 +410,7 @@ def visualize_dataset_html(
         ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        run_server(dataset, episodes, host, port, static_dir, template_dir, inference_dir)


 def main():
@@ -439,6 +477,12 @@ def main():
         default=0,
         help="Delete the output directory if it exists already.",
     )
+    parser.add_argument(
+        "--inference-dir",
+        type=Path,
+        default=None,
+        help="",
+    )

     args = parser.parse_args()
     kwargs = vars(args)