Add visualize_dataset_html with http.server (#188)
This commit is contained in:
300
lerobot/scripts/visualize_dataset_html.py
Normal file
300
lerobot/scripts/visualize_dataset_html.py
Normal file
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
|
||||
|
||||
Note: The last frame of the episode doesnt always correspond to a final state.
|
||||
That's because our datasets are composed of transition from state to state up to
|
||||
the antepenultimate state associated to the ultimate action to arrive in the final state.
|
||||
However, there might not be a transition from a final state to another state.
|
||||
|
||||
Note: This script aims to visualize the data used to train the neural networks.
|
||||
~What you see is what you get~. When visualizing image modality, it is often expected to observe
|
||||
lossly compression artifacts since these images have been decoded from compressed mp4 videos to
|
||||
save disk space. The compression factor applied has been tuned to not affect success rate.
|
||||
|
||||
Example of usage:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```bash
|
||||
local$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```bash
|
||||
distant$ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel
|
||||
local$ open http://localhost:9090
|
||||
```
|
||||
|
||||
- Select episodes to visualize:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def run_server(
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int],
|
||||
host: str,
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
# home page redirects to the first episode page
|
||||
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
|
||||
first_episode_id = episodes[0]
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=first_episode_id,
|
||||
)
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_samples,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = get_episode_video_paths(dataset, episode_id)
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
"visualize_dataset_template.html",
|
||||
episode_id=episode_id,
|
||||
episodes=episodes,
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=False,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
|
||||
|
||||
def get_ep_csv_fname(episode_id: int):
|
||||
ep_csv_fname = f"episode_{episode_id}.csv"
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = len(dataset.hf_dataset["observation.state"][0])
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
columns += ["observation.state"]
|
||||
if has_action:
|
||||
columns += ["action"]
|
||||
|
||||
rows = []
|
||||
data = dataset.hf_dataset.select_columns(columns)
|
||||
for i in range(from_idx, to_idx):
|
||||
row = [data[i]["timestamp"].item()]
|
||||
if has_state:
|
||||
row += data[i]["observation.state"].tolist()
|
||||
if has_action:
|
||||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
for row in rows:
|
||||
row_str = [str(col) for col in row]
|
||||
f.write(",".join(row_str) + "\n")
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.video_frame_keys
|
||||
]
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
static_dir = output_dir / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(range(dataset.num_episodes))
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serve",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch web server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Web host used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9090,
|
||||
help="Web port used by the http server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user