From b6aedcd9a5d2088b7853e9fcbcc9ccfcff9709d9 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 17 Feb 2025 18:38:45 +0100 Subject: [PATCH] Revert "Replace ArgumentParse by draccus in visualize_dataset_html" This reverts commit d8746be37dcb84ebaa7896485150f0e5ad5dd3a3. --- lerobot/scripts/visualize_dataset_html.py | 164 +++++++++++++++------- 1 file changed, 110 insertions(+), 54 deletions(-) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index 5ab66bcb9..cc3f39308 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -52,16 +52,15 @@ python lerobot/scripts/visualize_dataset_html.py \ ``` """ +import argparse import csv import json import logging import re import shutil import tempfile -from dataclasses import asdict, dataclass from io import StringIO from pathlib import Path -from pprint import pformat import numpy as np import pandas as pd @@ -72,7 +71,6 @@ from lerobot import available_datasets from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.utils import IterableNamespace from lerobot.common.utils.utils import init_logging -from lerobot.configs import parser def run_server( @@ -326,67 +324,42 @@ def get_dataset_info(repo_id: str) -> IterableNamespace: return IterableNamespace(dataset_info) -@dataclass -class VisualizeDatasetHtmlConfig: - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # Episode indices to visualize (e.g. `'[0,1,5,6]'` to load episodes of index 0, 1, 5 and 6). By default loads all episodes. - episodes: list[int] | None = None - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument - # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists. - local_files_only: bool = False - # Load videos and parquet files from HF Hub rather than local system. - load_from_hf_hub: bool = False - # Launch web server. - serve: bool = True - # Web host used by the http server. - host: str = "127.0.0.1" - # Web port used by the http server. - port: int = 9090 - # Delete the output directory if it exists already. - force_override: bool = False - - def __post_init__(self): - if self.output_dir is None: - # Create a temporary directory that will be automatically cleaned up - self.output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") - - self.output_dir = Path(self.output_dir) - - -@parser.wrap() -def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig): +def visualize_dataset_html( + dataset: LeRobotDataset | None, + episodes: list[int] | None = None, + output_dir: Path | None = None, + serve: bool = True, + host: str = "127.0.0.1", + port: int = 9090, + force_override: bool = False, +) -> Path | None: init_logging() - logging.info(pformat(asdict(cfg))) - - dataset = ( - LeRobotDataset(cfg.repo_id, root=cfg.root, local_files_only=cfg.local_files_only) - if not cfg.load_from_hf_hub - else get_dataset_info(cfg.repo_id) - ) template_dir = Path(__file__).resolve().parent.parent / "templates" - if cfg.output_dir.exists(): - if cfg.force_override: - shutil.rmtree(cfg.output_dir) + if output_dir is None: + # Create a temporary directory that will be automatically cleaned up + output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") + + output_dir = Path(output_dir) + if output_dir.exists(): + if force_override: + shutil.rmtree(output_dir) else: - logging.info(f"Output directory already exists. Loading from it: '{cfg.output_dir}'") + logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") - cfg.output_dir.mkdir(parents=True, exist_ok=True) + output_dir.mkdir(parents=True, exist_ok=True) - static_dir = cfg.output_dir / "static" + static_dir = output_dir / "static" static_dir.mkdir(parents=True, exist_ok=True) if dataset is None: - if cfg.serve: + if serve: run_server( dataset=None, episodes=None, - host=cfg.host, - port=cfg.port, + host=host, + port=port, static_folder=static_dir, template_folder=template_dir, ) @@ -398,9 +371,92 @@ def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig): if not ln_videos_dir.exists(): ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) - if cfg.serve: - run_server(dataset, cfg.episodes, cfg.host, cfg.port, static_dir, template_dir) + if serve: + run_server(dataset, episodes, host, port, static_dir, template_dir) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + default=None, + help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", + ) + parser.add_argument( + "--local-files-only", + type=int, + default=0, + help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", + ) + parser.add_argument( + "--root", + type=Path, + default=None, + help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", + ) + parser.add_argument( + "--load-from-hf-hub", + type=int, + default=0, + help="Load videos and parquet files from HF Hub rather than local system.", + ) + parser.add_argument( + "--episodes", + type=int, + nargs="*", + default=None, + help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", + ) + parser.add_argument( + "--serve", + type=int, + default=1, + help="Launch web server.", + ) + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="Web host used by the http server.", + ) + parser.add_argument( + "--port", + type=int, + default=9090, + help="Web port used by the http server.", + ) + parser.add_argument( + "--force-override", + type=int, + default=0, + help="Delete the output directory if it exists already.", + ) + + args = parser.parse_args() + kwargs = vars(args) + repo_id = kwargs.pop("repo_id") + load_from_hf_hub = kwargs.pop("load_from_hf_hub") + root = kwargs.pop("root") + local_files_only = kwargs.pop("local_files_only") + + dataset = None + if repo_id: + dataset = ( + LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) + if not load_from_hf_hub + else get_dataset_info(repo_id) + ) + + visualize_dataset_html(dataset, **vars(args)) if __name__ == "__main__": - visualize_dataset_html() + main()