From 1da5caaf4b3f144c0c228681485cb9f11d5e1b93 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Tue, 6 Aug 2024 17:16:42 +0300
Subject: [PATCH] Revert "Revove inference"

This reverts commit ca7f207d74a8071c35f133c02135302ed20f6327.
---
 lerobot/scripts/visualize_dataset_html.py | 179 +++++++++++++++++++++-
 tests/test_visualize_dataset_html.py      |  36 +++++
 2 files changed, 209 insertions(+), 6 deletions(-)
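
For context, the restored script can also be driven directly from Python. Below is a minimal sketch of such a call once this patch is applied; the checkpoint directory is a placeholder, and the keyword arguments mirror the CLI flags re-added further down:

```python
# Sketch of driving the restored entry point from Python.
# The checkpoint directory is a placeholder for a policy saved with Policy.save_pretrained.
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

visualize_dataset_html(
    repo_id="lerobot/pusht",
    episodes=[0],
    serve=False,  # only write the per-episode CSV and safetensors files, do not start the Flask server
    pretrained_policy_name_or_path="outputs/train/diffusion_pusht/checkpoints/last/pretrained_model",
    policy_method="select_action",  # or "forward" to run the training-style forward pass
    policy_overrides=["device=cpu"],
)
```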
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index 2531fbd0..00145aad 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -50,19 +50,33 @@ python lerobot/scripts/visualize_dataset_html.py \
     --repo-id lerobot/pusht \
     --episodes 7 3 5 1 4
 ```
+
+- Run inference of a policy on the dataset and visualize the results:
+```bash
+python lerobot/scripts/visualize_dataset_html.py \
+    --repo-id lerobot/pusht \
+    --episodes 7 3 5 1 4 \
+    -p lerobot/diffusion_pusht \
+    --policy-overrides device=cpu
+```
 """

 import argparse
 import logging
 import shutil
+import warnings
 from pathlib import Path

 import torch
 import tqdm
 from flask import Flask, redirect, render_template, url_for
+from safetensors.torch import load_file, save_file

+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.utils.utils import init_logging
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.utils import get_pretrained_policy_path
+from lerobot.common.utils.utils import init_hydra_config, init_logging


 class EpisodeSampler(torch.utils.data.Sampler):
@@ -85,6 +99,7 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
+    has_policy: bool = False,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -124,7 +139,7 @@
             dataset_info=dataset_info,
             videos_info=videos_info,
             ep_csv_url=ep_csv_url,
-            has_policy=False,
+            has_policy=has_policy,
         )

     app.run(host=host, port=port)
@@ -135,7 +150,7 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     """Write a CSV file containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     from_idx = dataset.episode_data_index["from"][episode_index]
@@ -143,6 +158,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):

     has_state = "observation.state" in dataset.hf_dataset.features
     has_action = "action" in dataset.hf_dataset.features
+    has_inference = inference_results is not None

     # init header of csv with state and action names
     header = ["timestamp"]
@@ -152,6 +168,13 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     if has_action:
         dim_action = len(dataset.hf_dataset["action"][0])
         header += [f"action_{i}" for i in range(dim_action)]
+    if has_inference:
+        if "action" in inference_results:
+            dim_pred_action = inference_results["action"].shape[1]
+            header += [f"pred_action_{i}" for i in range(dim_pred_action)]
+        for key in inference_results:
+            if "loss" in key:
+                header += [key]

     columns = ["timestamp"]
     if has_state:
@@ -169,6 +192,18 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
                 row += data[i]["action"].tolist()
         rows.append(row)

+    if has_inference:
+        num_frames = len(rows)
+        if "action" in inference_results:
+            assert num_frames == inference_results["action"].shape[0]
+            for i in range(num_frames):
+                rows[i] += inference_results["action"][i].tolist()
+        for key in inference_results:
+            if "loss" in key:
+                assert num_frames == inference_results[key].shape[0]
+                for i in range(num_frames):
+                    rows[i] += [inference_results[key][i].item()]
+
     output_dir.mkdir(parents=True, exist_ok=True)
     with open(output_dir / file_name, "w") as f:
         f.write(",".join(header) + "\n")
@@ -186,6 +221,75 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]


+def run_inference(
+    dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
+):
+    if policy_method not in ["select_action", "forward"]:
+        raise ValueError(
+            f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
+        )
+
+    policy.eval()
+    policy.to(device)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        # When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
+        batch_size=1 if policy_method == "select_action" else batch_size,
+        sampler=episode_sampler,
+        drop_last=False,
+    )
+
+    warned_ndim_eq_0 = False
+    warned_ndim_gt_2 = False
+
+    logging.info("Running inference")
+    inference_results = {}
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        with torch.inference_mode():
+            if policy_method == "select_action":
+                gt_action = batch.pop("action")
+                output_dict = {"action": policy.select_action(batch)}
+                batch["action"] = gt_action
+            elif policy_method == "forward":
+                output_dict = policy.forward(batch)
+                # TODO(rcadene): Save and display all predicted actions at a given timestamp
+                # Save predicted action for the next timestamp only
+                output_dict["action"] = output_dict["action"][:, 0, :]
+
+        for key in output_dict:
+            if output_dict[key].ndim == 0:
+                if not warned_ndim_eq_0:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
+                        stacklevel=1,
+                    )
+                    warned_ndim_eq_0 = True
+                continue
+
+            if output_dict[key].ndim > 2:
+                if not warned_ndim_gt_2:
+                    warnings.warn(
+                        f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
+                        stacklevel=1,
+                    )
+                    warned_ndim_gt_2 = True
+                continue
+
+            if key not in inference_results:
+                inference_results[key] = []
+            inference_results[key].append(output_dict[key].to("cpu"))
+
+    for key in inference_results:
+        inference_results[key] = torch.cat(inference_results[key])
+
+    return inference_results
+
+
 def visualize_dataset_html(
     repo_id: str,
     root: Path | None = None,
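
A note on the two inference modes re-added above: `select_action` replays the episode one frame at a time (batch size 1), as if observations were streaming from an environment, while `forward` runs the training-style forward pass in batches and keeps only the first predicted action per timestamp. Either way, `run_inference` accumulates per-frame tensors that `write_episode_data_csv` consumes. A rough sketch of that contract, using made-up shapes for illustration:

```python
# Illustrative only: the kind of dictionary run_inference() returns, with made-up shapes.
import torch

num_frames, action_dim = 100, 2  # e.g. a pusht episode with 2-D actions

inference_results = {
    # concatenated over all frames of the episode, one predicted action per frame
    "action": torch.zeros(num_frames, action_dim),
    # only present with policy_method="forward" and a per-sample loss; scalar,
    # batch-aggregated losses are skipped with the warning shown above
    "loss": torch.zeros(num_frames),
}

# write_episode_data_csv() appends pred_action_{i} columns for "action" and one column per
# key containing "loss", so every tensor's first dimension must equal num_frames.
print({key: tuple(tensor.shape) for key, tensor in inference_results.items()})
```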
@@ -195,10 +299,28 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    policy_method: str = "select_action",
+    pretrained_policy_name_or_path: str | None = None,
+    policy_overrides: list[str] | None = None,
 ) -> Path | None:
     init_logging()

-    dataset = LeRobotDataset(repo_id, root=root)
+    has_policy = pretrained_policy_name_or_path is not None
+
+    if has_policy:
+        logging.info("Loading policy")
+        pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
+
+        hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
+        dataset = make_dataset(hydra_cfg)
+        policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+
+        if policy_method == "select_action":
+            # Do not load previous observations or future actions, to simulate that the observations come from
+            # an environment.
+            dataset.delta_timestamps = None
+    else:
+        dataset = LeRobotDataset(repo_id, root=root)

     if not dataset.video:
         raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
@@ -206,6 +328,11 @@ def visualize_dataset_html(

     if output_dir is None:
         output_dir = f"outputs/visualize_dataset_html/{repo_id}"
+        if has_policy:
+            ckpt_str = pretrained_policy_path.parts[-2]
+            exp_name = pretrained_policy_path.parts[-4]
+            output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
+
     output_dir = Path(output_dir)
     if output_dir.exists():
         if force_override:
@@ -230,13 +357,31 @@ def visualize_dataset_html(

     logging.info("Writing CSV files")
     for episode_index in tqdm.tqdm(episodes):
+        inference_results = None
+        if has_policy:
+            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
+            if inference_results_path.exists():
+                inference_results = load_file(inference_results_path)
+            else:
+                inference_results = run_inference(
+                    dataset,
+                    episode_index,
+                    policy,
+                    policy_method,
+                    num_workers=hydra_cfg.training.num_workers,
+                    batch_size=hydra_cfg.training.batch_size,
+                    device=hydra_cfg.device,
+                )
+                inference_results_path.parent.mkdir(parents=True, exist_ok=True)
+                save_file(inference_results, inference_results_path)
+
         # write states and actions in a csv (it can be slow for big datasets)
         ep_csv_fname = get_ep_csv_fname(episode_index)
         # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
+        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)

     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)


 def main():
@@ -292,6 +437,28 @@ def main():
         help="Delete the output directory if it exists already.",
     )

+    parser.add_argument(
+        "--policy-method",
+        type=str,
+        default="select_action",
+        choices=["select_action", "forward"],
+        help="Python method used to run the inference. By default, set to `select_action`, which is used during evaluation to output the sequence of actions. Can be set to `forward`, which is used during training to compute the loss.",
+    )
+    parser.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        type=str,
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`."
+        ),
+    )
+    parser.add_argument(
+        "--policy-overrides",
+        nargs="*",
+        help="Any key=value arguments to override policy config values (use dots for nested overrides)",
+    )
+
     args = parser.parse_args()

     visualize_dataset_html(**vars(args))
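
When a policy is passed, the default output directory gets a suffix derived from the checkpoint path (the `parts[-2]` / `parts[-4]` indexing in the hunk above). A small sketch of that derivation, assuming a typical training layout such as `outputs/train/<experiment>/checkpoints/<step>/pretrained_model`; the concrete path below is hypothetical:

```python
# Sketch of the output_dir suffix derivation; the checkpoint path is hypothetical.
from pathlib import PurePath

pretrained_policy_path = PurePath("outputs/train/diffusion_pusht/checkpoints/100000/pretrained_model")
ckpt_str = pretrained_policy_path.parts[-2]  # "100000"
exp_name = pretrained_policy_path.parts[-4]  # "diffusion_pusht"
suffix = f"_{exp_name}_{ckpt_str}_select_action"  # appended to outputs/visualize_dataset_html/<repo_id>
print(suffix)
```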
diff --git a/tests/test_visualize_dataset_html.py b/tests/test_visualize_dataset_html.py
index 4dc3c063..77ababfa 100644
--- a/tests/test_visualize_dataset_html.py
+++ b/tests/test_visualize_dataset_html.py
@@ -18,7 +18,12 @@ from pathlib import Path

 import pytest

+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.logger import Logger
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
+from tests.utils import DEFAULT_CONFIG_PATH


 @pytest.mark.parametrize(
@@ -34,3 +39,34 @@ def test_visualize_dataset_html(tmpdir, repo_id):
         serve=False,
     )
     assert (tmpdir / "static" / "episode_0.csv").exists()
+
+
+@pytest.mark.parametrize(
+    "repo_id, policy_method",
+    [
+        ("lerobot/pusht", "select_action"),
+        ("lerobot/pusht", "forward"),
+    ],
+)
+def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
+    tmpdir = Path(tmpdir)
+
+    # Create a policy
+    cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
+    dataset = make_dataset(cfg)
+    policy = make_policy(cfg, dataset_stats=dataset.stats)
+
+    # Save a checkpoint
+    logger = Logger(cfg, tmpdir)
+    logger.save_model(tmpdir, policy)
+
+    visualize_dataset_html(
+        repo_id,
+        episodes=[0],
+        output_dir=tmpdir,
+        serve=False,
+        pretrained_policy_name_or_path=tmpdir,
+        policy_method=policy_method,
+    )
+    assert (tmpdir / "static" / "episode_0.csv").exists()
+    assert (tmpdir / "episode_0.safetensors").exists()
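
Beyond the CSV asserted above, the script caches raw inference results as one safetensors file per episode in the output directory (the CSVs go into its `static/` subfolder), and reuses that cache on the next run instead of re-running inference. A short sketch of inspecting such a file; the output directory below is a placeholder:

```python
# Sketch: inspecting a cached per-episode inference result; the directory is a placeholder.
from pathlib import Path

from safetensors.torch import load_file

output_dir = Path("outputs/visualize_dataset_html/lerobot/pusht_diffusion_pusht_100000_select_action")
results = load_file(output_dir / "episode_0.safetensors")
for key, tensor in results.items():
    print(key, tuple(tensor.shape))  # e.g. "action" -> (num_frames, action_dim)
```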