diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index da73000e..6534343d 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -60,13 +60,31 @@ import shutil
 import socketserver
 from pathlib import Path
 
+import torch
 import tqdm
+import yaml
 from bs4 import BeautifulSoup
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.policies.act.modeling_act import ACTPolicy
 from lerobot.common.utils.utils import init_logging
 
 
+class EpisodeSampler(torch.utils.data.Sampler):
+    """Sampler that yields all frame indices of one episode, in order."""
+
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
+
+
 class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def end_headers(self):
         self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
@@ -114,10 +132,10 @@ def create_html_page(page_title: str):
     main_div = soup.new_tag("div")
     body.append(main_div)
 
-    return soup, body
+    return soup, head, body
 
 
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     """Write a csv file containg timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     from_idx = dataset.episode_data_index["from"][episode_index]
@@ -125,6 +143,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
 
     has_state = "observation.state" in dataset.hf_dataset.features
     has_action = "action" in dataset.hf_dataset.features
+    has_inference = inference_results is not None
 
     # init header of csv with state and action names
     header = ["timestamp"]
@@ -134,6 +153,12 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     if has_action:
         dim_action = len(dataset.hf_dataset["action"][0])
         header += [f"action_{i}" for i in range(dim_action)]
+    if has_inference:
+        assert "actions" in inference_results
+        assert "loss" in inference_results
+        dim_pred_action = inference_results["actions"].shape[2]
+        header += [f"pred_action_{i}" for i in range(dim_pred_action)]
+        header += ["loss"]
 
     columns = ["timestamp"]
     if has_state:
@@ -151,6 +176,14 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
             row += data[i]["action"].tolist()
         rows.append(row)
 
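+    # `actions` is expected to have shape (num_frames, chunk_size, action_dim);
+    # only the first action of each predicted chunk is written to the csv, next
+    # to the ground-truth action of the same frame.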
+ s += " highlightCircleSize: 3\n" + s += " }\n" s += " });\n" s += "\n" s += " // Function to play both videos\n" @@ -215,7 +254,14 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset): def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): """Write an html file containg video feeds and timeseries associated to an episode.""" - soup, body = create_html_page("") + soup, head, body = create_html_page("") + + css_style = soup.new_tag("style") + css_style.string = "" + css_style.string += "#labels > span.highlight {\n" + css_style.string += " border: 1px solid grey;\n" + css_style.string += "}" + head.append(css_style) # Add videos from camera feeds @@ -295,7 +341,7 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset): """Write an html file containing information related to the dataset and a list of links to html pages of episodes.""" - soup, body = create_html_page("TODO") + soup, head, body = create_html_page("TODO") h3 = soup.new_tag("h3") h3.string = "TODO" @@ -337,17 +383,68 @@ def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, f.write(soup.prettify()) +def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"): + policy.eval() + policy.to(device) + + logging.info("Loading dataloader") + episode_sampler = EpisodeSampler(dataset, episode_index) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_size=batch_size, + sampler=episode_sampler, + ) + + logging.info("Running inference") + inference_results = {} + for batch in tqdm.tqdm(dataloader, total=len(dataloader)): + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + with torch.inference_mode(): + output_dict = policy.forward(batch) + + for key in output_dict: + if key not in inference_results: + inference_results[key] = [] + inference_results[key].append(output_dict[key].to("cpu")) + + for key in inference_results: + inference_results[key] = torch.cat(inference_results[key]) + + return inference_results + + def visualize_dataset( repo_id: str, episode_indices: list[int] = None, output_dir: Path | None = None, serve: bool = True, port: int = 9090, + force_overwrite: bool = True, + policy_repo_id: str | None = None, + policy_ckpt_path: Path | None = None, + batch_size: int = 32, + num_workers: int = 4, ) -> Path | None: init_logging() + has_policy = policy_repo_id or policy_ckpt_path + + if has_policy: + logging.info("Loading policy") + if policy_repo_id: + pretrained_policy_path = Path(snapshot_download(policy_repo_id)) + elif policy_ckpt_path: + pretrained_policy_path = Path(policy_ckpt_path) + policy = ACTPolicy.from_pretrained(pretrained_policy_path) + with open(pretrained_policy_path / "config.yaml") as f: + cfg = yaml.safe_load(f) + delta_timestamps = cfg["training"]["delta_timestamps"] + else: + delta_timestamps = None + logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) + dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) if not dataset.video: raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") @@ -356,31 +453,41 @@ def visualize_dataset( output_dir = f"outputs/visualize_dataset/{repo_id}" output_dir = Path(output_dir) - if output_dir.exists(): + if force_overwrite and output_dir.exists(): shutil.rmtree(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Create a 
+        if has_policy:
+            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
+            if inference_results_path.exists():
+                inference_results = load_file(inference_results_path)
+            else:
+                inference_results = run_inference(
+                    dataset, episode_index, policy, num_workers=num_workers, batch_size=batch_size
+                )
+                save_file(inference_results, inference_results_path)
 
-        js_fname = f"episode_{episode_idx}.js"
+        # write states and actions in a csv
+        ep_csv_fname = f"episode_{episode_index}.csv"
+        write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
+
+        js_fname = f"episode_{episode_index}.js"
         write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
 
         # write a html page to view videos and timeseries
-        ep_html_fname = f"episode_{episode_idx}.html"
-        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset)
+        ep_html_fname = f"episode_{episode_index}.html"
+        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
         ep_html_fnames.append(ep_html_fname)
 
     write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
@@ -396,7 +503,7 @@ def main():
         "--repo-id",
         type=str,
         required=True,
-        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
         "--episode-indices",
@@ -423,6 +530,37 @@ def main():
         default=9090,
         help="Web port used by the http server.",
     )
+    parser.add_argument(
+        "--force-overwrite",
+        type=int,
+        default=1,
+        help="Delete the output directory if it already exists.",
+    )
+
+    parser.add_argument(
+        "--policy-repo-id",
+        type=str,
+        default=None,
+        help="Name of the Hugging Face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
+    )
+    parser.add_argument(
+        "--policy-ckpt-path",
+        type=str,
+        default=None,
+        help="Path to a local directory containing a pretrained policy checkpoint (its config.yaml and model weights).",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size used by the DataLoader during inference.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of DataLoader worker processes for loading the data.",
+    )
 
     args = parser.parse_args()
     visualize_dataset(**vars(args))
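
A sketch of how the updated script might be invoked (the checkpoint path is a
placeholder; any directory holding a config.yaml and weights loadable by
ACTPolicy.from_pretrained should work):

    python lerobot/scripts/visualize_dataset.py \
        --repo-id lerobot/pusht \
        --policy-ckpt-path path/to/pretrained_model \
        --batch-size 32 \
        --num-workers 4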