Revert "Replace ArgumentParse by draccus in visualize_dataset_html"

This reverts commit d8746be37dcb84ebaa7896485150f0e5ad5dd3a3.
This commit is contained in:
Remi Cadene
2025-02-17 18:38:45 +01:00
parent 121030cca7
commit b6aedcd9a5

View File

@@ -52,16 +52,15 @@ python lerobot/scripts/visualize_dataset_html.py \
``` ```
""" """
import argparse
import csv import csv
import json import json
import logging import logging
import re import re
import shutil import shutil
import tempfile import tempfile
from dataclasses import asdict, dataclass
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from pprint import pformat
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -72,7 +71,6 @@ from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.utils.utils import init_logging from lerobot.common.utils.utils import init_logging
from lerobot.configs import parser
def run_server( def run_server(
@@ -326,67 +324,42 @@ def get_dataset_info(repo_id: str) -> IterableNamespace:
return IterableNamespace(dataset_info) return IterableNamespace(dataset_info)
@dataclass def visualize_dataset_html(
class VisualizeDatasetHtmlConfig: dataset: LeRobotDataset | None,
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). episodes: list[int] | None = None,
repo_id: str output_dir: Path | None = None,
# Episode indices to visualize (e.g. `'[0,1,5,6]'` to load episodes of index 0, 1, 5 and 6). By default loads all episodes. serve: bool = True,
episodes: list[int] | None = None host: str = "127.0.0.1",
# Root directory where the dataset will be stored (e.g. 'dataset/path'). port: int = 9090,
root: str | Path | None = None force_override: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument ) -> Path | None:
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
# Load videos and parquet files from HF Hub rather than local system.
load_from_hf_hub: bool = False
# Launch web server.
serve: bool = True
# Web host used by the http server.
host: str = "127.0.0.1"
# Web port used by the http server.
port: int = 9090
# Delete the output directory if it exists already.
force_override: bool = False
def __post_init__(self):
if self.output_dir is None:
# Create a temporary directory that will be automatically cleaned up
self.output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
self.output_dir = Path(self.output_dir)
@parser.wrap()
def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig):
init_logging() init_logging()
logging.info(pformat(asdict(cfg)))
dataset = (
LeRobotDataset(cfg.repo_id, root=cfg.root, local_files_only=cfg.local_files_only)
if not cfg.load_from_hf_hub
else get_dataset_info(cfg.repo_id)
)
template_dir = Path(__file__).resolve().parent.parent / "templates" template_dir = Path(__file__).resolve().parent.parent / "templates"
if cfg.output_dir.exists(): if output_dir is None:
if cfg.force_override: # Create a temporary directory that will be automatically cleaned up
shutil.rmtree(cfg.output_dir) output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = Path(output_dir)
if output_dir.exists():
if force_override:
shutil.rmtree(output_dir)
else: else:
logging.info(f"Output directory already exists. Loading from it: '{cfg.output_dir}'") logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
cfg.output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
static_dir = cfg.output_dir / "static" static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True) static_dir.mkdir(parents=True, exist_ok=True)
if dataset is None: if dataset is None:
if cfg.serve: if serve:
run_server( run_server(
dataset=None, dataset=None,
episodes=None, episodes=None,
host=cfg.host, host=host,
port=cfg.port, port=port,
static_folder=static_dir, static_folder=static_dir,
template_folder=template_dir, template_folder=template_dir,
) )
@@ -398,9 +371,92 @@ def visualize_dataset_html(cfg: VisualizeDatasetHtmlConfig):
if not ln_videos_dir.exists(): if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
if cfg.serve: if serve:
run_server(dataset, cfg.episodes, cfg.host, cfg.port, static_dir, template_dir) run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
    """CLI entry point: parse arguments, optionally load a dataset, and launch the HTML visualizer.

    Boolean-like flags are plain 0/1 integers (e.g. `--serve 0` disables the
    web server), matching the script's original CLI contract.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--local-files-only",
        type=int,
        default=0,
        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--load-from-hf-hub",
        type=int,
        default=0,
        help="Load videos and parquet files from HF Hub rather than local system.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default, a temporary directory is created and used.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    args = parser.parse_args()
    kwargs = vars(args)

    # Pop the dataset-loading options; the remaining entries map 1:1 onto
    # visualize_dataset_html() keyword parameters.
    repo_id = kwargs.pop("repo_id")
    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
    root = kwargs.pop("root")
    local_files_only = kwargs.pop("local_files_only")

    dataset = None
    if repo_id:
        # Either instantiate the full dataset locally, or only fetch its
        # metadata from the hub when --load-from-hf-hub is set.
        dataset = (
            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
            if not load_from_hf_hub
            else get_dataset_info(repo_id)
        )

    # vars(args) aliases the same dict as kwargs (argparse exposes args.__dict__),
    # so pass kwargs explicitly rather than relying on that aliasing.
    visualize_dataset_html(dataset, **kwargs)
if __name__ == "__main__": if __name__ == "__main__":
visualize_dataset_html() main()