forked from tangger/lerobot
@@ -50,19 +50,33 @@ python lerobot/scripts/visualize_dataset_html.py \
|
|||||||
--repo-id lerobot/pusht \
|
--repo-id lerobot/pusht \
|
||||||
--episodes 7 3 5 1 4
|
--episodes 7 3 5 1 4
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Run inference of a policy on the dataset and visualize the results:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/visualize_dataset_html.py \
|
||||||
|
--repo-id lerobot/pusht \
|
||||||
|
--episodes 7 3 5 1 4
|
||||||
|
-p lerobot/diffusion_pusht \
|
||||||
|
--policy-overrides device=cpu
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from flask import Flask, redirect, render_template, url_for
|
from flask import Flask, redirect, render_template, url_for
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.utils.utils import init_logging
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.policies.utils import get_pretrained_policy_path
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSampler(torch.utils.data.Sampler):
|
class EpisodeSampler(torch.utils.data.Sampler):
|
||||||
@@ -85,6 +99,7 @@ def run_server(
|
|||||||
port: str,
|
port: str,
|
||||||
static_folder: Path,
|
static_folder: Path,
|
||||||
template_folder: Path,
|
template_folder: Path,
|
||||||
|
has_policy: bool = False,
|
||||||
):
|
):
|
||||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||||
@@ -124,7 +139,7 @@ def run_server(
|
|||||||
dataset_info=dataset_info,
|
dataset_info=dataset_info,
|
||||||
videos_info=videos_info,
|
videos_info=videos_info,
|
||||||
ep_csv_url=ep_csv_url,
|
ep_csv_url=ep_csv_url,
|
||||||
has_policy=False,
|
has_policy=has_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.run(host=host, port=port)
|
app.run(host=host, port=port)
|
||||||
@@ -135,7 +150,7 @@ def get_ep_csv_fname(episode_id: int):
|
|||||||
return ep_csv_fname
|
return ep_csv_fname
|
||||||
|
|
||||||
|
|
||||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
|
||||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||||
@@ -143,6 +158,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
|
|
||||||
has_state = "observation.state" in dataset.hf_dataset.features
|
has_state = "observation.state" in dataset.hf_dataset.features
|
||||||
has_action = "action" in dataset.hf_dataset.features
|
has_action = "action" in dataset.hf_dataset.features
|
||||||
|
has_inference = inference_results is not None
|
||||||
|
|
||||||
# init header of csv with state and action names
|
# init header of csv with state and action names
|
||||||
header = ["timestamp"]
|
header = ["timestamp"]
|
||||||
@@ -152,6 +168,13 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
if has_action:
|
if has_action:
|
||||||
dim_action = len(dataset.hf_dataset["action"][0])
|
dim_action = len(dataset.hf_dataset["action"][0])
|
||||||
header += [f"action_{i}" for i in range(dim_action)]
|
header += [f"action_{i}" for i in range(dim_action)]
|
||||||
|
if has_inference:
|
||||||
|
if "action" in inference_results:
|
||||||
|
dim_pred_action = inference_results["action"].shape[1]
|
||||||
|
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
|
||||||
|
for key in inference_results:
|
||||||
|
if "loss" in key:
|
||||||
|
header += [key]
|
||||||
|
|
||||||
columns = ["timestamp"]
|
columns = ["timestamp"]
|
||||||
if has_state:
|
if has_state:
|
||||||
@@ -169,6 +192,18 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
row += data[i]["action"].tolist()
|
row += data[i]["action"].tolist()
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
|
||||||
|
if has_inference:
|
||||||
|
num_frames = len(rows)
|
||||||
|
if "action" in inference_results:
|
||||||
|
assert num_frames == inference_results["action"].shape[0]
|
||||||
|
for i in range(num_frames):
|
||||||
|
rows[i] += inference_results["action"][i].tolist()
|
||||||
|
for key in inference_results:
|
||||||
|
if "loss" in key:
|
||||||
|
assert num_frames == inference_results[key].shape[0]
|
||||||
|
for i in range(num_frames):
|
||||||
|
rows[i] += [inference_results[key][i].item()]
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
with open(output_dir / file_name, "w") as f:
|
with open(output_dir / file_name, "w") as f:
|
||||||
f.write(",".join(header) + "\n")
|
f.write(",".join(header) + "\n")
|
||||||
@@ -186,6 +221,75 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_inference(
|
||||||
|
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
|
||||||
|
):
|
||||||
|
if policy_method not in ["select_action", "forward"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
policy.eval()
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
|
logging.info("Loading dataloader")
|
||||||
|
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=num_workers,
|
||||||
|
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
|
||||||
|
batch_size=1 if policy_method == "select_action" else batch_size,
|
||||||
|
sampler=episode_sampler,
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
warned_ndim_eq_0 = False
|
||||||
|
warned_ndim_gt_2 = False
|
||||||
|
|
||||||
|
logging.info("Running inference")
|
||||||
|
inference_results = {}
|
||||||
|
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||||
|
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||||
|
with torch.inference_mode():
|
||||||
|
if policy_method == "select_action":
|
||||||
|
gt_action = batch.pop("action")
|
||||||
|
output_dict = {"action": policy.select_action(batch)}
|
||||||
|
batch["action"] = gt_action
|
||||||
|
elif policy_method == "forward":
|
||||||
|
output_dict = policy.forward(batch)
|
||||||
|
# TODO(rcadene): Save and display all predicted actions at a given timestamp
|
||||||
|
# Save predicted action for the next timestamp only
|
||||||
|
output_dict["action"] = output_dict["action"][:, 0, :]
|
||||||
|
|
||||||
|
for key in output_dict:
|
||||||
|
if output_dict[key].ndim == 0:
|
||||||
|
if not warned_ndim_eq_0:
|
||||||
|
warnings.warn(
|
||||||
|
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
warned_ndim_eq_0 = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if output_dict[key].ndim > 2:
|
||||||
|
if not warned_ndim_gt_2:
|
||||||
|
warnings.warn(
|
||||||
|
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
warned_ndim_gt_2 = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key not in inference_results:
|
||||||
|
inference_results[key] = []
|
||||||
|
inference_results[key].append(output_dict[key].to("cpu"))
|
||||||
|
|
||||||
|
for key in inference_results:
|
||||||
|
inference_results[key] = torch.cat(inference_results[key])
|
||||||
|
|
||||||
|
return inference_results
|
||||||
|
|
||||||
|
|
||||||
def visualize_dataset_html(
|
def visualize_dataset_html(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
@@ -195,10 +299,28 @@ def visualize_dataset_html(
|
|||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 9090,
|
port: int = 9090,
|
||||||
force_override: bool = False,
|
force_override: bool = False,
|
||||||
|
policy_method: str = "select_action",
|
||||||
|
pretrained_policy_name_or_path: str | None = None,
|
||||||
|
policy_overrides: list[str] | None = None,
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id, root=root)
|
has_policy = pretrained_policy_name_or_path is not None
|
||||||
|
|
||||||
|
if has_policy:
|
||||||
|
logging.info("Loading policy")
|
||||||
|
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||||
|
|
||||||
|
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||||
|
dataset = make_dataset(hydra_cfg)
|
||||||
|
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||||
|
|
||||||
|
if policy_method == "select_action":
|
||||||
|
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||||
|
# an environment.
|
||||||
|
dataset.delta_timestamps = None
|
||||||
|
else:
|
||||||
|
dataset = LeRobotDataset(repo_id, root=root)
|
||||||
|
|
||||||
if not dataset.video:
|
if not dataset.video:
|
||||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||||
@@ -206,6 +328,11 @@ def visualize_dataset_html(
|
|||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||||
|
|
||||||
|
if has_policy:
|
||||||
|
ckpt_str = pretrained_policy_path.parts[-2]
|
||||||
|
exp_name = pretrained_policy_path.parts[-4]
|
||||||
|
output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
|
||||||
|
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
if output_dir.exists():
|
if output_dir.exists():
|
||||||
if force_override:
|
if force_override:
|
||||||
@@ -230,13 +357,31 @@ def visualize_dataset_html(
|
|||||||
|
|
||||||
logging.info("Writing CSV files")
|
logging.info("Writing CSV files")
|
||||||
for episode_index in tqdm.tqdm(episodes):
|
for episode_index in tqdm.tqdm(episodes):
|
||||||
|
inference_results = None
|
||||||
|
if has_policy:
|
||||||
|
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
|
||||||
|
if inference_results_path.exists():
|
||||||
|
inference_results = load_file(inference_results_path)
|
||||||
|
else:
|
||||||
|
inference_results = run_inference(
|
||||||
|
dataset,
|
||||||
|
episode_index,
|
||||||
|
policy,
|
||||||
|
policy_method,
|
||||||
|
num_workers=hydra_cfg.training.num_workers,
|
||||||
|
batch_size=hydra_cfg.training.batch_size,
|
||||||
|
device=hydra_cfg.device,
|
||||||
|
)
|
||||||
|
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_file(inference_results, inference_results_path)
|
||||||
|
|
||||||
# write states and actions in a csv (it can be slow for big datasets)
|
# write states and actions in a csv (it can be slow for big datasets)
|
||||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
|
||||||
|
|
||||||
if serve:
|
if serve:
|
||||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -292,6 +437,28 @@ def main():
|
|||||||
help="Delete the output directory if it exists already.",
|
help="Delete the output directory if it exists already.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy-method",
|
||||||
|
type=str,
|
||||||
|
default="select_action",
|
||||||
|
choices=["select_action", "forward"],
|
||||||
|
help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--pretrained-policy-name-or-path",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||||
|
"saved using `Policy.save_pretrained`."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy-overrides",
|
||||||
|
nargs="*",
|
||||||
|
help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
visualize_dataset_html(**vars(args))
|
visualize_dataset_html(**vars(args))
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.logger import Logger
|
||||||
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||||
|
from tests.utils import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -34,3 +39,34 @@ def test_visualize_dataset_html(tmpdir, repo_id):
|
|||||||
serve=False,
|
serve=False,
|
||||||
)
|
)
|
||||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"repo_id, policy_method",
|
||||||
|
[
|
||||||
|
("lerobot/pusht", "select_action"),
|
||||||
|
("lerobot/pusht", "forward"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
|
||||||
|
tmpdir = Path(tmpdir)
|
||||||
|
|
||||||
|
# Create a policy
|
||||||
|
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
|
||||||
|
dataset = make_dataset(cfg)
|
||||||
|
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||||
|
|
||||||
|
# Save a checkpoint
|
||||||
|
logger = Logger(cfg, tmpdir)
|
||||||
|
logger.save_model(tmpdir, policy)
|
||||||
|
|
||||||
|
visualize_dataset_html(
|
||||||
|
repo_id,
|
||||||
|
episodes=[0],
|
||||||
|
output_dir=tmpdir,
|
||||||
|
serve=False,
|
||||||
|
pretrained_policy_name_or_path=tmpdir,
|
||||||
|
policy_method=policy_method,
|
||||||
|
)
|
||||||
|
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||||
|
assert (tmpdir / "episode_0.safetensors").exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user