From a66a792029fbe7f63897f4d85759d38a373ef2b6 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Tue, 18 Feb 2025 10:08:45 +0100 Subject: [PATCH] Add save inference --- lerobot/common/policies/act/modeling_act.py | 22 ++- lerobot/scripts/save_inference.py | 159 ++++++++++++++++++++ lerobot/scripts/visualize_dataset_html.py | 36 ++++- 3 files changed, 205 insertions(+), 12 deletions(-) create mode 100644 lerobot/scripts/save_inference.py diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index f2b16a1e..e134ca1c 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -155,21 +155,29 @@ class ACTPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none") + l1_loss *= ~batch["action_is_pad"].unsqueeze(-1) - loss_dict = {"l1_loss": l1_loss.item()} + bsize, seqlen, num_motors = l1_loss.shape + loss_dict = { + "l1_loss": l1_loss.mean().item(), + "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1), + } if self.config.use_vae: # Calculate D_KL(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). 
"""Run a trained policy over a full dataset and save per-frame inference outputs.

For every frame, `policy.forward` is called and every output whose key contains
"loss_per_item" (one value per batch element) is collected. The values are then
grouped by episode and written as one `output_features_episode_<idx>.pth` file
per episode, for later display by `visualize_dataset_html.py --inference-dir`.
"""

import logging
import shutil
import tempfile
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

import torch
import tqdm

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
    auto_select_torch_device,
    init_logging,
    is_amp_available,
    is_torch_device_available,
)
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
class SaveInferenceConfig:
    """Configuration for saving per-frame inference outputs of a policy on a dataset."""

    dataset: DatasetConfig
    # Delete the output directory if it exists already.
    force_override: bool = False

    batch_size: int = 16
    num_workers: int = 4

    policy: PreTrainedConfig | None = None
    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool | None = None

    # Where per-episode feature files are written. When None, a fresh temporary
    # directory is created (NOTE: it is NOT auto-deleted; its path is printed at the end).
    output_dir: str | Path | None = None

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

            # When no device or use_amp are given, use the one from training config.
            if self.device is None or self.use_amp is None:
                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
                if self.device is None:
                    self.device = train_cfg.device
                if self.use_amp is None:
                    self.use_amp = train_cfg.use_amp

        # Automatically switch to available device if necessary
        if not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False


@parser.wrap()
def save_inference(cfg: SaveInferenceConfig):
    """Compute per-frame policy outputs over `cfg.dataset` and save them per episode.

    Raises:
        FileExistsError: if `cfg.output_dir` exists and `force_override` is False.
        ValueError: if a "*loss_per_item" output is not a 1-D per-batch-element tensor.
    """
    init_logging()
    logging.info(pformat(asdict(cfg)))

    dataset = make_dataset(cfg)
    policy = make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)

    output_dir = cfg.output_dir
    if output_dir is None:
        # NOTE: mkdtemp does NOT clean up after itself; the path is printed at the
        # end so the user can find (and delete) the saved features.
        output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
    elif Path(output_dir).exists():
        if cfg.force_override:
            shutil.rmtree(output_dir)
        else:
            # FIX: was NotImplementedError, which misstates the condition.
            raise FileExistsError(
                f"Output directory already exists: {output_dir}. Set `force_override=true` to overwrite."
            )

    output_dir = Path(output_dir)
    # Hoisted out of the per-episode loop: the directory only needs creating once.
    output_dir.mkdir(parents=True, exist_ok=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=False,  # keep dataset order so episode/frame indices stay aligned
        pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    # NOTE(review): train() keeps dropout (and, for VAE policies, latent sampling)
    # active, so saved losses mirror training-time statistics. Use eval() instead if
    # deterministic outputs are wanted — confirm the intent with the author.
    policy.train()

    episode_indices = []
    frame_indices = []
    feats = {}  # maps output key -> list of 1-D per-item tensors, one per batch

    for batch in tqdm.tqdm(dataloader):
        # Move all tensor entries to the target device; non-tensor metadata stays put.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

        with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
            _, output_dict = policy.forward(batch)

        bsize = batch["episode_index"].shape[0]
        episode_indices.append(batch["episode_index"])
        frame_indices.append(batch["frame_index"])

        for key in output_dict:
            # Only per-frame outputs are kept; scalar metrics are skipped.
            if "loss_per_item" not in key:
                continue

            if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
                raise ValueError(
                    f"Expected '{key}' to have shape ({bsize},), got {tuple(output_dict[key].shape)}"
                )

            feats.setdefault(key, []).append(output_dict[key])

    episode_indices = torch.cat(episode_indices)
    frame_indices = torch.cat(frame_indices)

    for key in feats:
        feats[key] = torch.cat(feats[key])

    # One file per episode so the visualizer can load episodes independently.
    for episode in torch.unique(episode_indices):
        ep_feats = {
            # FIX: `.data` is legacy autograd API; detach() is the supported equivalent.
            key: feats[key][episode_indices == episode].detach().cpu()
            for key in feats
        }
        torch.save(ep_feats, output_dir / f"output_features_episode_{int(episode)}.pth")

    print(f"Features can be found in: {output_dir}")


if __name__ == "__main__":
    save_inference()
- episode_data_csv_str, columns = get_episode_data(dataset, episode_id) + episode_data_csv_str, columns = get_episode_data(dataset, episode_id, inference_dir) dataset_info = { "repo_id": f"{dataset_namespace}/{dataset_name}", "num_samples": dataset.num_frames @@ -228,7 +237,9 @@ def get_ep_csv_fname(episode_id: int): return ep_csv_fname -def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index): +def get_episode_data( + dataset: LeRobotDataset | IterableNamespace, episode_index, inference_dir: Path | None = None +): """Get a csv str containing timeseries data of an episode (e.g. state and action). This file will be loaded by Dygraph javascript to plot data in real time.""" columns = [] @@ -279,7 +290,15 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index) np.expand_dims(data["timestamp"], axis=1), *[np.vstack(data[col]) for col in selected_columns[1:]], ) - ).tolist() + ) + + if inference_dir is not None: + feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth") + for key in feats: + header.append(key.replace("loss_per_item", "loss")) + rows = np.concatenate([rows, feats[key][:, None]], axis=1) + + rows = rows.tolist() # Convert data to CSV string csv_buffer = StringIO() @@ -332,6 +351,7 @@ def visualize_dataset_html( host: str = "127.0.0.1", port: int = 9090, force_override: bool = False, + inference_dir: Path | None = None, ) -> Path | None: init_logging() @@ -372,7 +392,7 @@ def visualize_dataset_html( ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) if serve: - run_server(dataset, episodes, host, port, static_dir, template_dir) + run_server(dataset, episodes, host, port, static_dir, template_dir, inference_dir) def main(): @@ -439,6 +459,12 @@ def main(): default=0, help="Delete the output directory if it exists already.", ) + parser.add_argument( + "--inference-dir", + type=Path, + default=None, + help="", + ) args = parser.parse_args() kwargs = vars(args)