diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index cc3f3930..5ab66bcb 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -52,15 +52,16 @@ python lerobot/scripts/visualize_dataset_html.py \
 ```
 """
 
-import argparse
 import csv
 import json
 import logging
 import re
 import shutil
 import tempfile
+from dataclasses import asdict, dataclass
 from io import StringIO
 from pathlib import Path
+from pprint import pformat
 
 import numpy as np
 import pandas as pd
@@ -71,6 +72,7 @@ from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import IterableNamespace
 from lerobot.common.utils.utils import init_logging
+from lerobot.configs import parser
 
 
 def run_server(
@@ -324,42 +326,69 @@ def get_dataset_info(repo_id: str) -> IterableNamespace:
     return IterableNamespace(dataset_info)
 
 
-def visualize_dataset_html(
-    dataset: LeRobotDataset | None,
-    episodes: list[int] | None = None,
-    output_dir: Path | None = None,
-    serve: bool = True,
-    host: str = "127.0.0.1",
-    port: int = 9090,
-    force_override: bool = False,
-) -> Path | None:
+@dataclass
+class VisualizeDatasetHtmlConfig:
+    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
+    repo_id: str
+    # Episode indices to visualize (e.g. `'[0,1,5,6]'` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.
+    episodes: list[int] | None = None
+    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    root: str | Path | None = None
+    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
+    # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
+    local_files_only: bool = False
+    # Load videos and parquet files from HF Hub rather than local system.
+    load_from_hf_hub: bool = False
+    # Directory path to write html files and kickoff a web server. By default a temporary directory is created.
+    output_dir: Path | None = None
+    # Launch web server.
+    serve: bool = True
+    # Web host used by the http server.
+    host: str = "127.0.0.1"
+    # Web port used by the http server.
+    port: int = 9090
+    # Delete the output directory if it exists already.
+    force_override: bool = False
+
+    def __post_init__(self):
+        if self.output_dir is None:
+            # Create a temporary directory that will be automatically cleaned up
+            self.output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
+
+        self.output_dir = Path(self.output_dir)
+
+
+@parser.wrap()
+def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig):
     init_logging()
+    logging.info(pformat(asdict(cfg)))
+
+    dataset = (
+        LeRobotDataset(cfg.repo_id, root=cfg.root, local_files_only=cfg.local_files_only)
+        if not cfg.load_from_hf_hub
+        else get_dataset_info(cfg.repo_id)
+    )
 
     template_dir = Path(__file__).resolve().parent.parent / "templates"
 
-    if output_dir is None:
-        # Create a temporary directory that will be automatically cleaned up
-        output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
-
-    output_dir = Path(output_dir)
-    if output_dir.exists():
-        if force_override:
-            shutil.rmtree(output_dir)
+    if cfg.output_dir.exists():
+        if cfg.force_override:
+            shutil.rmtree(cfg.output_dir)
         else:
-            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
+            logging.info(f"Output directory already exists. Loading from it: '{cfg.output_dir}'")
 
-    output_dir.mkdir(parents=True, exist_ok=True)
+    cfg.output_dir.mkdir(parents=True, exist_ok=True)
 
-    static_dir = output_dir / "static"
+    static_dir = cfg.output_dir / "static"
     static_dir.mkdir(parents=True, exist_ok=True)
 
     if dataset is None:
-        if serve:
+        if cfg.serve:
             run_server(
                 dataset=None,
                 episodes=None,
-                host=host,
-                port=port,
+                host=cfg.host,
+                port=cfg.port,
                 static_folder=static_dir,
                 template_folder=template_dir,
             )
@@ -371,92 +400,9 @@ def visualize_dataset_html(
     if not ln_videos_dir.exists():
         ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
 
-    if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
-
-
-def main():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--repo-id",
-        type=str,
-        default=None,
-        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
-    )
-    parser.add_argument(
-        "--local-files-only",
-        type=int,
-        default=0,
-        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
-    )
-    parser.add_argument(
-        "--root",
-        type=Path,
-        default=None,
-        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
-    )
-    parser.add_argument(
-        "--load-from-hf-hub",
-        type=int,
-        default=0,
-        help="Load videos and parquet files from HF Hub rather than local system.",
-    )
-    parser.add_argument(
-        "--episodes",
-        type=int,
-        nargs="*",
-        default=None,
-        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=Path,
-        default=None,
-        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
-    )
-    parser.add_argument(
-        "--serve",
-        type=int,
-        default=1,
-        help="Launch web server.",
-    )
-    parser.add_argument(
-        "--host",
-        type=str,
-        default="127.0.0.1",
-        help="Web host used by the http server.",
-    )
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=9090,
-        help="Web port used by the http server.",
-    )
-    parser.add_argument(
-        "--force-override",
-        type=int,
-        default=0,
-        help="Delete the output directory if it exists already.",
-    )
-
-    args = parser.parse_args()
-    kwargs = vars(args)
-    repo_id = kwargs.pop("repo_id")
-    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
-    root = kwargs.pop("root")
-    local_files_only = kwargs.pop("local_files_only")
-
-    dataset = None
-    if repo_id:
-        dataset = (
-            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-            if not load_from_hf_hub
-            else get_dataset_info(repo_id)
-        )
-
-    visualize_dataset_html(dataset, **vars(args))
+    if cfg.serve:
+        run_server(dataset, cfg.episodes, cfg.host, cfg.port, static_dir, template_dir)
 
 
 if __name__ == "__main__":
-    main()
+    visualize_dataset_html()