Replace ArgumentParser with draccus in visualize_dataset_html

This commit is contained in:
Remi Cadene
2025-02-17 18:26:42 +01:00
parent fe483b1d0d
commit 121030cca7

View File

@@ -52,15 +52,16 @@ python lerobot/scripts/visualize_dataset_html.py \
```
"""
import argparse
import csv
import json
import logging
import re
import shutil
import tempfile
from dataclasses import asdict, dataclass
from io import StringIO
from pathlib import Path
from pprint import pformat
import numpy as np
import pandas as pd
@@ -71,6 +72,7 @@ from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.utils.utils import init_logging
from lerobot.configs import parser
def run_server(
@@ -324,42 +326,67 @@ def get_dataset_info(repo_id: str) -> IterableNamespace:
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
) -> Path | None:
@dataclass
class VisualizeDatasetHtmlConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# Episode indices to visualize (e.g. `'[0,1,5,6]'` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.
episodes: list[int] | None = None
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
# Load videos and parquet files from HF Hub rather than local system.
load_from_hf_hub: bool = False
# Launch web server.
serve: bool = True
# Web host used by the http server.
host: str = "127.0.0.1"
# Web port used by the http server.
port: int = 9090
# Delete the output directory if it exists already.
force_override: bool = False
def __post_init__(self):
if self.output_dir is None:
# Create a temporary directory that will be automatically cleaned up
self.output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
self.output_dir = Path(self.output_dir)
@parser.wrap()
def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
dataset = (
LeRobotDataset(cfg.repo_id, root=cfg.root, local_files_only=cfg.local_files_only)
if not cfg.load_from_hf_hub
else get_dataset_info(cfg.repo_id)
)
template_dir = Path(__file__).resolve().parent.parent / "templates"
if output_dir is None:
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
if cfg.output_dir.exists():
if cfg.force_override:
shutil.rmtree(cfg.output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
logging.info(f"Output directory already exists. Loading from it: '{cfg.output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
cfg.output_dir.mkdir(parents=True, exist_ok=True)
static_dir = output_dir / "static"
static_dir = cfg.output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None:
if serve:
if cfg.serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
host=cfg.host,
port=cfg.port,
static_folder=static_dir,
template_folder=template_dir,
)
@@ -371,92 +398,9 @@ def visualize_dataset_html(
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
    """Parse CLI arguments, optionally build a dataset, and launch the visualizer.

    The dataset-construction options (`--repo-id`, `--root`, `--local-files-only`,
    `--load-from-hf-hub`) are consumed here; every remaining option is forwarded
    to `visualize_dataset_html` as a keyword argument, so the two argument sets
    must stay in sync with that function's signature.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--local-files-only",
        type=int,
        default=0,
        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--load-from-hf-hub",
        type=int,
        default=0,
        help="Load videos and parquet files from HF Hub rather than local system.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default, a temporary directory is created and used.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # Pop the dataset-construction options so only visualize_dataset_html's
    # keyword arguments remain in `kwargs`.
    repo_id = kwargs.pop("repo_id")
    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
    root = kwargs.pop("root")
    local_files_only = kwargs.pop("local_files_only")

    dataset = None
    if repo_id:
        # Either load the dataset locally (default) or fetch only its metadata
        # from the hub so files can be streamed instead of downloaded.
        dataset = (
            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
            if not load_from_hf_hub
            else get_dataset_info(repo_id)
        )

    visualize_dataset_html(dataset, **kwargs)
if cfg.serve:
run_server(dataset, cfg.episodes, cfg.host, cfg.port, static_dir, template_dir)
if __name__ == "__main__":
main()
visualize_dataset_html()