Add save inference

Remi Cadene
2025-02-18 10:08:45 +01:00
parent b6aedcd9a5
commit a66a792029
3 changed files with 205 additions and 12 deletions
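
ACTPolicy.forward now reports per-frame losses (l1_loss_per_item, kld_loss_per_item, loss_per_item) alongside the scalar losses. A new save_inference script runs a trained policy over a dataset and saves these per-frame losses to one .pth file per episode, and the HTML dataset visualizer gains an --inference-dir option to plot them alongside the state/action timeseries.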


@@ -155,21 +155,29 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)
-        loss_dict = {"l1_loss": l1_loss.item()}
+        bsize, seqlen, num_motors = l1_loss.shape
+        loss_dict = {
+            "l1_loss": l1_loss.mean().item(),
+            "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
+        }
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
+            mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            loss_dict["kld_loss_per_item"] = mean_kld
+            mean_kld = mean_kld.mean()
             loss_dict["kld_loss"] = mean_kld.item()
             loss = l1_loss + mean_kld * self.config.kl_weight
+            loss_dict["loss_per_item"] = (
+                loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
+            )
         else:
             loss = l1_loss
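
The L1 term is now kept unreduced so that, besides the scalar l1_loss, a per-batch-element l1_loss_per_item (and, with the VAE, kld_loss_per_item and loss_per_item) can be reported. Note that as displayed the returned loss is still built from the unreduced l1_loss, so it is a tensor rather than a scalar; a training loop would need a final reduction (e.g. loss.mean()) before backward, unless the truncated tail of the hunk handles it. A minimal sketch of the per-item reduction with dummy shapes (sizes are illustrative, not from the repo):

import torch
import torch.nn.functional as F

bsize, seqlen, num_motors = 2, 100, 6
actions = torch.randn(bsize, seqlen, num_motors)       # ground-truth action chunks
actions_hat = torch.randn(bsize, seqlen, num_motors)   # predicted action chunks
action_is_pad = torch.zeros(bsize, seqlen, dtype=torch.bool)  # True on padded steps

# Element-wise L1, zeroed on padded timesteps, then one scalar per batch element.
l1 = F.l1_loss(actions, actions_hat, reduction="none")
l1 = l1 * ~action_is_pad.unsqueeze(-1)
l1_per_item = l1.view(bsize, seqlen * num_motors).mean(dim=1)

assert l1_per_item.shape == (bsize,)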


@@ -0,0 +1,159 @@
import logging
import shutil
import tempfile
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

import torch
import tqdm

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
    auto_select_torch_device,
    init_logging,
    is_amp_available,
    is_torch_device_available,
)
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
class SaveInferenceConfig:
    dataset: DatasetConfig
    # Delete the output directory if it exists already.
    force_override: bool = False
    batch_size: int = 16
    num_workers: int = 4
    policy: PreTrainedConfig | None = None
    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for the inference forward
    # passes (no gradients are computed in this script).
    use_amp: bool | None = None
    output_dir: str | Path | None = None

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def __post_init__(self):
        # HACK: We parse the cli args again here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

            # When no device or use_amp are given, use the ones from the training config.
            if self.device is None or self.use_amp is None:
                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
                if self.device is None:
                    self.device = train_cfg.device
                if self.use_amp is None:
                    self.use_amp = train_cfg.use_amp

        # Automatically switch to an available device if necessary
        if not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

@parser.wrap()
def save_inference(cfg: SaveInferenceConfig):
    init_logging()
    logging.info(pformat(asdict(cfg)))

    dataset = make_dataset(cfg)
    policy = make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)

    output_dir = cfg.output_dir
    if output_dir is None:
        # Create a temporary directory (note: `tempfile.mkdtemp` does not clean it up automatically).
        output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
    elif Path(output_dir).exists():
        if cfg.force_override:
            shutil.rmtree(output_dir)
        else:
            raise FileExistsError(
                f"Output directory already exists: {output_dir}. Set `force_override=true` to overwrite it."
            )
    output_dir = Path(output_dir)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    # Keep the policy in train mode so modules like dropout behave as during training;
    # gradients are still disabled below with `torch.no_grad`.
    policy.train()
    episode_indices = []
    frame_indices = []
    feats = {}
    for batch in tqdm.tqdm(dataloader):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

        with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
            _, output_dict = policy.forward(batch)

        bsize = batch["episode_index"].shape[0]
        episode_indices.append(batch["episode_index"])
        frame_indices.append(batch["frame_index"])

        for key in output_dict:
            # Only keep the per-item losses (e.g. "l1_loss_per_item", "kld_loss_per_item", "loss_per_item").
            if "loss_per_item" not in key:
                continue
            if key not in feats:
                feats[key] = []
            if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
                raise ValueError(
                    f"Expected '{key}' to be a 1D tensor with one value per batch element ({bsize},), "
                    f"got shape {output_dict[key].shape}."
                )
            feats[key].append(output_dict[key])

    episode_indices = torch.cat(episode_indices)
    frame_indices = torch.cat(frame_indices)
    for key in feats:
        feats[key] = torch.cat(feats[key])

    # Save one file of per-frame features for each episode.
    output_dir.mkdir(parents=True, exist_ok=True)
    unique_episodes = torch.unique(episode_indices)
    for episode in unique_episodes:
        ep_feats = {}
        for key in feats:
            ep_feats[key] = feats[key][episode_indices == episode].cpu()
        torch.save(ep_feats, output_dir / f"output_features_episode_{int(episode)}.pth")

    logging.info(f"Features can be found in: {output_dir}")


if __name__ == "__main__":
    save_inference()
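
Assuming the draccus-style CLI that parser.wrap() provides (flag names inferred from the config fields above; the checkpoint path and repo id below are placeholders), a run might look like:

python -m lerobot.scripts.save_inference \
    --policy.path=outputs/train/act_pusht/checkpoints/last/pretrained_model \
    --dataset.repo_id=lerobot/pusht \
    --output_dir=outputs/save_inference/pusht

Each episode ends up in its own output_features_episode_<n>.pth file: a dict mapping loss names (l1_loss_per_item, kld_loss_per_item, loss_per_item) to a 1D tensor with one value per frame.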


@@ -65,6 +65,7 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import requests
+import torch
 from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
@@ -80,6 +81,7 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
+    inference_dir: Path | None,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -139,7 +141,14 @@ def run_server(
     )

     @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+    def show_episode(
+        dataset_namespace,
+        dataset_name,
+        episode_id,
+        dataset=dataset,
+        episodes=episodes,
+        inference_dir=inference_dir,
+    ):
         repo_id = f"{dataset_namespace}/{dataset_name}"
         try:
             if dataset is None:
@@ -158,7 +167,7 @@ def run_server(
         if major_version < 2:
             return "Make sure to convert your LeRobotDataset to v2 & above."

-        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id, inference_dir)

         dataset_info = {
             "repo_id": f"{dataset_namespace}/{dataset_name}",
             "num_samples": dataset.num_frames
@@ -228,7 +237,9 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+def get_episode_data(
+    dataset: LeRobotDataset | IterableNamespace, episode_index, inference_dir: Path | None = None
+):
     """Get a csv str containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     columns = []
@@ -279,7 +290,15 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
                 np.expand_dims(data["timestamp"], axis=1),
                 *[np.vstack(data[col]) for col in selected_columns[1:]],
             )
-        ).tolist()
+        )
+
+        if inference_dir is not None:
+            feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
+            for key in feats:
+                header.append(key.replace("loss_per_item", "loss"))
+                rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+
+        rows = rows.tolist()

     # Convert data to CSV string
     csv_buffer = StringIO()
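
Each saved tensor holds one value per frame, so it is appended to the per-frame rows as an extra column, with the *_loss_per_item keys shortened to *_loss for the header. A toy sketch of the merge with dummy data (none of these values come from the repo):

import numpy as np
import torch

header = ["timestamp", "action_0"]
rows = np.zeros((3, 2))  # 3 frames, 2 existing columns
feats = {"l1_loss_per_item": torch.tensor([0.1, 0.2, 0.3])}

for key in feats:
    header.append(key.replace("loss_per_item", "loss"))  # "l1_loss_per_item" -> "l1_loss"
    rows = np.concatenate([rows, feats[key][:, None].numpy()], axis=1)

assert rows.shape == (3, 3) and header[-1] == "l1_loss"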
@@ -332,6 +351,7 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    inference_dir: Path | None = None,
 ) -> Path | None:
     init_logging()
@@ -372,7 +392,7 @@ def visualize_dataset_html(
         ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        run_server(dataset, episodes, host, port, static_dir, template_dir, inference_dir)


 def main():
@@ -439,6 +459,12 @@ def main():
         default=0,
         help="Delete the output directory if it exists already.",
     )
+    parser.add_argument(
+        "--inference-dir",
+        type=Path,
+        default=None,
+        help="Directory containing the per-episode output features saved by the save_inference script.",
+    )

     args = parser.parse_args()
     kwargs = vars(args)
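
Once features are saved, they can be overlaid in the visualizer. A sketch of a local run, assuming the script's existing --repo-id flag and using placeholder paths:

python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht \
    --inference-dir outputs/save_inference/pusht

The renamed loss columns then show up in the Dygraph plot next to the state and action timeseries.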