forked from tangger/lerobot
Add offline inference to visualize_dataset
@@ -60,13 +60,31 @@ import shutil
 import socketserver
 from pathlib import Path
 
+import torch
 import tqdm
+import yaml
 from bs4 import BeautifulSoup
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file, save_file
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.policies.act.modeling_act import ACTPolicy
 from lerobot.common.utils.utils import init_logging
 
 
+class EpisodeSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, episode_index):
+        from_idx = dataset.episode_data_index["from"][episode_index].item()
+        to_idx = dataset.episode_data_index["to"][episode_index].item()
+        self.frame_ids = range(from_idx, to_idx)
+
+    def __iter__(self):
+        return iter(self.frame_ids)
+
+    def __len__(self):
+        return len(self.frame_ids)
+
+
 class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def end_headers(self):
         self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
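EpisodeSampler yields the global frame indices of a single episode in temporal order, so a DataLoader built on top of it walks through exactly that episode, batch by batch. A minimal sketch of the pattern (the repo id and batch size here are illustrative, not from the commit):

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("lerobot/pusht")  # example dataset
    sampler = EpisodeSampler(dataset, episode_index=0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
    for batch in loader:
        # each batch holds consecutive frames of episode 0, in order
        print(batch["index"].shape)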
@@ -114,10 +132,10 @@ def create_html_page(page_title: str):
 
     main_div = soup.new_tag("div")
     body.append(main_div)
-    return soup, body
+    return soup, head, body
 
 
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
+def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
     """Write a csv file containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     from_idx = dataset.episode_data_index["from"][episode_index]
@@ -125,6 +143,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
 
     has_state = "observation.state" in dataset.hf_dataset.features
     has_action = "action" in dataset.hf_dataset.features
+    has_inference = inference_results is not None
 
     # init header of csv with state and action names
     header = ["timestamp"]
@@ -134,6 +153,12 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     if has_action:
         dim_action = len(dataset.hf_dataset["action"][0])
         header += [f"action_{i}" for i in range(dim_action)]
+    if has_inference:
+        assert "actions" in inference_results
+        assert "loss" in inference_results
+        dim_pred_action = inference_results["actions"].shape[2]
+        header += [f"pred_action_{i}" for i in range(dim_pred_action)]
+        header += ["loss"]
 
     columns = ["timestamp"]
     if has_state:
@@ -151,6 +176,14 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
             row += data[i]["action"].tolist()
         rows.append(row)
 
+    if has_inference:
+        num_frames = len(rows)
+        assert num_frames == inference_results["actions"].shape[0]
+        assert num_frames == inference_results["loss"].shape[0]
+        for i in range(num_frames):
+            rows[i] += inference_results["actions"][i, 0].tolist()
+            rows[i] += [inference_results["loss"][i].item()]
+
     output_dir.mkdir(parents=True, exist_ok=True)
     with open(output_dir / file_name, "w") as f:
         f.write(",".join(header) + "\n")
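When inference results are present, each row gains the first step of the predicted action chunk (actions[i, 0]) and a scalar loss, appended after the recorded state and action columns. Schematically, for a hypothetical 2-dimensional state and action space (placeholder values, not real data), the csv looks like:

    timestamp,state_0,state_1,action_0,action_1,pred_action_0,pred_action_1,loss
    0.00,<s0>,<s1>,<a0>,<a1>,<p0>,<p1>,<l>
    0.04,<s0>,<s1>,<a0>,<a1>,<p0>,<p1>,<l>
    ...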
@@ -172,7 +205,13 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
     s += " legend: 'always',\n"
     s += " labelsDiv: document.getElementById('labels'),\n"
     s += " labelsSeparateLines: true,\n"
-    s += " labelsKMB: true\n"
+    s += " labelsKMB: true,\n"
+    s += " highlightCircleSize: 1.5,\n"
+    s += " highlightSeriesOpts: {\n"
+    s += "   strokeWidth: 1.5,\n"
+    s += "   strokeBorderWidth: 1,\n"
+    s += "   highlightCircleSize: 3\n"
+    s += " }\n"
     s += " });\n"
     s += "\n"
     s += " // Function to play both videos\n"
@@ -215,7 +254,14 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
 
 def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
     """Write an html file containing video feeds and timeseries associated to an episode."""
-    soup, body = create_html_page("")
+    soup, head, body = create_html_page("")
+
+    css_style = soup.new_tag("style")
+    css_style.string = ""
+    css_style.string += "#labels > span.highlight {\n"
+    css_style.string += "  border: 1px solid grey;\n"
+    css_style.string += "}"
+    head.append(css_style)
 
     # Add videos from camera feeds
 
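Returning head from create_html_page exists to support this injection: the stylesheet draws a border around the legend entry of the hovered series, pairing with the highlightSeriesOpts added to the Dygraph config above. The same bs4 pattern in isolation (a standalone sketch, not code from the commit):

    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<html><head></head><body></body></html>", "html.parser")
    style = soup.new_tag("style")
    style.string = "#labels > span.highlight { border: 1px solid grey; }"
    soup.head.append(style)  # the <style> tag ends up inside <head>
    print(soup.prettify())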
@@ -295,7 +341,7 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
 def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
     """Write an html file containing information related to the dataset and a list of links to
     html pages of episodes."""
-    soup, body = create_html_page("TODO")
+    soup, head, body = create_html_page("TODO")
 
     h3 = soup.new_tag("h3")
     h3.string = "TODO"
@@ -337,17 +383,68 @@ def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames,
         f.write(soup.prettify())
 
 
+def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
+    policy.eval()
+    policy.to(device)
+
+    logging.info("Loading dataloader")
+    episode_sampler = EpisodeSampler(dataset, episode_index)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_size=batch_size,
+        sampler=episode_sampler,
+    )
+
+    logging.info("Running inference")
+    inference_results = {}
+    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
+        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+        with torch.inference_mode():
+            output_dict = policy.forward(batch)
+
+        for key in output_dict:
+            if key not in inference_results:
+                inference_results[key] = []
+            inference_results[key].append(output_dict[key].to("cpu"))
+
+    for key in inference_results:
+        inference_results[key] = torch.cat(inference_results[key])
+
+    return inference_results
+
+
 def visualize_dataset(
     repo_id: str,
     episode_indices: list[int] = None,
     output_dir: Path | None = None,
     serve: bool = True,
     port: int = 9090,
     force_overwrite: bool = True,
+    policy_repo_id: str | None = None,
+    policy_ckpt_path: Path | None = None,
+    batch_size: int = 32,
+    num_workers: int = 4,
 ) -> Path | None:
     init_logging()
 
+    has_policy = policy_repo_id or policy_ckpt_path
+
+    if has_policy:
+        logging.info("Loading policy")
+        if policy_repo_id:
+            pretrained_policy_path = Path(snapshot_download(policy_repo_id))
+        elif policy_ckpt_path:
+            pretrained_policy_path = Path(policy_ckpt_path)
+        policy = ACTPolicy.from_pretrained(pretrained_policy_path)
+        with open(pretrained_policy_path / "config.yaml") as f:
+            cfg = yaml.safe_load(f)
+        delta_timestamps = cfg["training"]["delta_timestamps"]
+    else:
+        delta_timestamps = None
+
     logging.info("Loading dataset")
-    dataset = LeRobotDataset(repo_id)
+    dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
 
     if not dataset.video:
         raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
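run_inference concatenates each batch's outputs on the CPU, so the resulting dict holds per-frame tensors; the csv writer above expects "actions" of shape (num_frames, chunk_size, action_dim) and "loss" of shape (num_frames,). Note this relies on policy.forward returning per-sample outputs: the stock ACTPolicy.forward returns a batch-averaged scalar loss, so a forward pass that preserves per-sample values is assumed here. An illustrative shape check under that assumption:

    results = run_inference(dataset, episode_index=0, policy=policy)
    num_frames = len(EpisodeSampler(dataset, episode_index=0))
    assert results["actions"].shape[0] == num_frames  # (num_frames, chunk_size, action_dim)
    assert results["loss"].shape == (num_frames,)  # one scalar per frame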
@@ -356,31 +453,41 @@ def visualize_dataset(
         output_dir = f"outputs/visualize_dataset/{repo_id}"
 
     output_dir = Path(output_dir)
-    if output_dir.exists():
+    if force_overwrite and output_dir.exists():
         shutil.rmtree(output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
 
     # Create a symlink from the dataset video folder containing mp4 files to the output directory
     # so that the http server can get access to the mp4 files.
     ln_videos_dir = output_dir / "videos"
-    ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
+    if not ln_videos_dir.exists():
+        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
 
     if episode_indices is None:
         episode_indices = list(range(dataset.num_episodes))
 
     logging.info("Writing html")
     ep_html_fnames = []
-    for episode_idx in tqdm.tqdm(episode_indices):
-        # write states and actions in a csv
-        ep_csv_fname = f"episode_{episode_idx}.csv"
-        write_episode_data_csv(output_dir, ep_csv_fname, episode_idx, dataset)
+    for episode_index in tqdm.tqdm(episode_indices):
+        inference_results = None
+        if has_policy:
+            inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
+            if inference_results_path.exists():
+                inference_results = load_file(inference_results_path)
+            else:
+                inference_results = run_inference(dataset, episode_index, policy)
+                save_file(inference_results, inference_results_path)
 
-        js_fname = f"episode_{episode_idx}.js"
+        # write states and actions in a csv
+        ep_csv_fname = f"episode_{episode_index}.csv"
+        write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
+
+        js_fname = f"episode_{episode_index}.js"
         write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
 
         # write a html page to view videos and timeseries
-        ep_html_fname = f"episode_{episode_idx}.html"
-        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset)
+        ep_html_fname = f"episode_{episode_index}.html"
+        write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
         ep_html_fnames.append(ep_html_fname)
 
     write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
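Inference results are cached per episode as episode_{index}.safetensors next to the generated html, so re-rendering skips the forward passes; save_file requires a flat dict of CPU tensors, which run_inference already produces. One caveat: since force_overwrite deletes the whole output directory, the cache only survives across runs when overwriting is disabled, e.g. (argument values illustrative):

    visualize_dataset(
        repo_id="lerobot/pusht",  # example dataset
        policy_repo_id=my_act_policy_repo,  # hypothetical pretrained ACT policy repo
        force_overwrite=False,  # keep cached .safetensors between runs
    )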
@@ -396,7 +503,7 @@ def main():
         "--repo-id",
         type=str,
         required=True,
-        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
+        help="Name of Hugging Face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
         "--episode-indices",
@@ -423,6 +530,37 @@ def main():
         default=9090,
         help="Web port used by the http server.",
     )
+    parser.add_argument(
+        "--force-overwrite",
+        type=int,
+        default=1,
+        help="Delete the output directory if it already exists.",
+    )
+
+    parser.add_argument(
+        "--policy-repo-id",
+        type=str,
+        default=None,
+        help="Name of Hugging Face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
+    )
+    parser.add_argument(
+        "--policy-ckpt-path",
+        type=str,
+        default=None,
+        help="Path to a local directory containing a pretrained policy checkpoint (alternative to `--policy-repo-id`).",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=32,
+        help="Batch size used by the DataLoader.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of DataLoader worker processes used for loading the data.",
+    )
 
     args = parser.parse_args()
     visualize_dataset(**vars(args))
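End to end, a hypothetical invocation (the script path and repo ids are examples, and --episode-indices is assumed to accept a space-separated list as in the original script; note that only ACT checkpoints will load, since the code calls ACTPolicy.from_pretrained regardless of which policy the repo contains):

    python lerobot/scripts/visualize_dataset.py \
        --repo-id lerobot/aloha_sim_insertion_human \
        --policy-repo-id lerobot/act_aloha_sim_insertion_human \
        --episode-indices 0 1 2 \
        --force-overwrite 0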