Add offline inference to visualize_dataset

This commit is contained in:
Remi Cadene
2024-05-29 15:33:52 +00:00
parent 31a25d1dba
commit 255dbef76a

View File

@@ -60,13 +60,31 @@ import shutil
import socketserver import socketserver
from pathlib import Path from pathlib import Path
import torch
import tqdm import tqdm
import yaml
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.utils.utils import init_logging from lerobot.common.utils.utils import init_logging
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler yielding, in order, every frame index of a single episode."""

    def __init__(self, dataset, episode_index):
        # The dataset exposes per-episode [from, to) frame boundaries as tensors.
        start = dataset.episode_data_index["from"][episode_index].item()
        stop = dataset.episode_data_index["to"][episode_index].item()
        self.frame_ids = range(start, stop)

    def __iter__(self):
        yield from self.frame_ids

    def __len__(self):
        return len(self.frame_ids)
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def end_headers(self): def end_headers(self):
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate") self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
@@ -114,10 +132,10 @@ def create_html_page(page_title: str):
main_div = soup.new_tag("div") main_div = soup.new_tag("div")
body.append(main_div) body.append(main_div)
return soup, body return soup, head, body
def write_episode_data_csv(output_dir, file_name, episode_index, dataset): def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
"""Write a csv file containg timeseries data of an episode (e.g. state and action). """Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time.""" This file will be loaded by Dygraph javascript to plot data in real time."""
from_idx = dataset.episode_data_index["from"][episode_index] from_idx = dataset.episode_data_index["from"][episode_index]
@@ -125,6 +143,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
has_state = "observation.state" in dataset.hf_dataset.features has_state = "observation.state" in dataset.hf_dataset.features
has_action = "action" in dataset.hf_dataset.features has_action = "action" in dataset.hf_dataset.features
has_inference = inference_results is not None
# init header of csv with state and action names # init header of csv with state and action names
header = ["timestamp"] header = ["timestamp"]
@@ -134,6 +153,12 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
if has_action: if has_action:
dim_action = len(dataset.hf_dataset["action"][0]) dim_action = len(dataset.hf_dataset["action"][0])
header += [f"action_{i}" for i in range(dim_action)] header += [f"action_{i}" for i in range(dim_action)]
if has_inference:
assert "actions" in inference_results
assert "loss" in inference_results
dim_pred_action = inference_results["actions"].shape[2]
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
header += ["loss"]
columns = ["timestamp"] columns = ["timestamp"]
if has_state: if has_state:
@@ -151,6 +176,14 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
row += data[i]["action"].tolist() row += data[i]["action"].tolist()
rows.append(row) rows.append(row)
if has_inference:
num_frames = len(rows)
assert num_frames == inference_results["actions"].shape[0]
assert num_frames == inference_results["loss"].shape[0]
for i in range(num_frames):
rows[i] += inference_results["actions"][i, 0].tolist()
rows[i] += [inference_results["loss"][i].item()]
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f: with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n") f.write(",".join(header) + "\n")
@@ -172,7 +205,13 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
s += " legend: 'always',\n" s += " legend: 'always',\n"
s += " labelsDiv: document.getElementById('labels'),\n" s += " labelsDiv: document.getElementById('labels'),\n"
s += " labelsSeparateLines: true,\n" s += " labelsSeparateLines: true,\n"
s += " labelsKMB: true\n" s += " labelsKMB: true,\n"
s += " highlightCircleSize: 1.5,\n"
s += " highlightSeriesOpts: {\n"
s += " strokeWidth: 1.5,\n"
s += " strokeBorderWidth: 1,\n"
s += " highlightCircleSize: 3\n"
s += " }\n"
s += " });\n" s += " });\n"
s += "\n" s += "\n"
s += " // Function to play both videos\n" s += " // Function to play both videos\n"
@@ -215,7 +254,14 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset): def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
"""Write an html file containg video feeds and timeseries associated to an episode.""" """Write an html file containg video feeds and timeseries associated to an episode."""
soup, body = create_html_page("") soup, head, body = create_html_page("")
css_style = soup.new_tag("style")
css_style.string = ""
css_style.string += "#labels > span.highlight {\n"
css_style.string += " border: 1px solid grey;\n"
css_style.string += "}"
head.append(css_style)
# Add videos from camera feeds # Add videos from camera feeds
@@ -295,7 +341,7 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset): def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
"""Write an html file containing information related to the dataset and a list of links to """Write an html file containing information related to the dataset and a list of links to
html pages of episodes.""" html pages of episodes."""
soup, body = create_html_page("TODO") soup, head, body = create_html_page("TODO")
h3 = soup.new_tag("h3") h3 = soup.new_tag("h3")
h3.string = "TODO" h3.string = "TODO"
@@ -337,17 +383,68 @@ def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames,
f.write(soup.prettify()) f.write(soup.prettify())
def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
    """Run `policy` over every frame of one episode and gather its outputs.

    Args:
        dataset: frame-indexable dataset exposing `episode_data_index` boundaries
            (consumed via EpisodeSampler).
        episode_index: index of the episode to run inference on.
        policy: torch.nn.Module whose `forward(batch)` returns a dict of tensors
            batched along dim 0 (e.g. "actions", "loss").
        num_workers: number of DataLoader worker processes.
        batch_size: DataLoader batch size.
        device: device the policy and batch tensors are moved to.

    Returns:
        Dict mapping each key of the policy output to a single CPU tensor
        concatenated over all batches of the episode.
    """
    policy.eval()
    policy.to(device)

    logging.info("Loading dataloader")
    episode_sampler = EpisodeSampler(dataset, episode_index)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=episode_sampler,
    )

    logging.info("Running inference")
    inference_results = {}
    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # Only move tensors; batches may carry non-tensor metadata entries
        # (presumably strings/paths — the blanket `.to()` would crash on them).
        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        with torch.inference_mode():
            output_dict = policy.forward(batch)

        for key in output_dict:
            # Accumulate per-batch outputs on CPU so GPU memory is not retained.
            inference_results.setdefault(key, []).append(output_dict[key].to("cpu"))

    for key in inference_results:
        inference_results[key] = torch.cat(inference_results[key])

    return inference_results
def visualize_dataset( def visualize_dataset(
repo_id: str, repo_id: str,
episode_indices: list[int] = None, episode_indices: list[int] = None,
output_dir: Path | None = None, output_dir: Path | None = None,
serve: bool = True, serve: bool = True,
port: int = 9090, port: int = 9090,
force_overwrite: bool = True,
policy_repo_id: str | None = None,
policy_ckpt_path: Path | None = None,
batch_size: int = 32,
num_workers: int = 4,
) -> Path | None: ) -> Path | None:
init_logging() init_logging()
has_policy = policy_repo_id or policy_ckpt_path
if has_policy:
logging.info("Loading policy")
if policy_repo_id:
pretrained_policy_path = Path(snapshot_download(policy_repo_id))
elif policy_ckpt_path:
pretrained_policy_path = Path(policy_ckpt_path)
policy = ACTPolicy.from_pretrained(pretrained_policy_path)
with open(pretrained_policy_path / "config.yaml") as f:
cfg = yaml.safe_load(f)
delta_timestamps = cfg["training"]["delta_timestamps"]
else:
delta_timestamps = None
logging.info("Loading dataset") logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id) dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
if not dataset.video: if not dataset.video:
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.") raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
@@ -356,31 +453,41 @@ def visualize_dataset(
output_dir = f"outputs/visualize_dataset/{repo_id}" output_dir = f"outputs/visualize_dataset/{repo_id}"
output_dir = Path(output_dir) output_dir = Path(output_dir)
if output_dir.exists(): if force_overwrite and output_dir.exists():
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
# Create a simlink from the dataset video folder containg mp4 files to the output directory # Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files. # so that the http server can get access to the mp4 files.
ln_videos_dir = output_dir / "videos" ln_videos_dir = output_dir / "videos"
ln_videos_dir.symlink_to(dataset.videos_dir.resolve()) if not ln_videos_dir.exists():
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
if episode_indices is None: if episode_indices is None:
episode_indices = list(range(dataset.num_episodes)) episode_indices = list(range(dataset.num_episodes))
logging.info("Writing html") logging.info("Writing html")
ep_html_fnames = [] ep_html_fnames = []
for episode_idx in tqdm.tqdm(episode_indices): for episode_index in tqdm.tqdm(episode_indices):
# write states and actions in a csv inference_results = None
ep_csv_fname = f"episode_{episode_idx}.csv" if has_policy:
write_episode_data_csv(output_dir, ep_csv_fname, episode_idx, dataset) inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
if inference_results_path.exists():
inference_results = load_file(inference_results_path)
else:
inference_results = run_inference(dataset, episode_index, policy)
save_file(inference_results, inference_results_path)
js_fname = f"episode_{episode_idx}.js" # write states and actions in a csv
ep_csv_fname = f"episode_{episode_index}.csv"
write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
js_fname = f"episode_{episode_index}.js"
write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset) write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
# write a html page to view videos and timeseries # write a html page to view videos and timeseries
ep_html_fname = f"episode_{episode_idx}.html" ep_html_fname = f"episode_{episode_index}.html"
write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset) write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
ep_html_fnames.append(ep_html_fname) ep_html_fnames.append(ep_html_fname)
write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset) write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
@@ -396,7 +503,7 @@ def main():
"--repo-id", "--repo-id",
type=str, type=str,
required=True, required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).", help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
) )
parser.add_argument( parser.add_argument(
"--episode-indices", "--episode-indices",
@@ -423,6 +530,37 @@ def main():
default=9090, default=9090,
help="Web port used by the http server.", help="Web port used by the http server.",
) )
parser.add_argument(
    "--force-overwrite",
    type=int,
    default=1,
    help="Delete the output directory if it exists already.",
)
parser.add_argument(
    "--policy-repo-id",
    type=str,
    default=None,
    help="Name of hugging face repository containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
)
parser.add_argument(
    "--policy-ckpt-path",
    type=str,
    default=None,
    # The previous help text was copy-pasted from --policy-repo-id and wrongly
    # described a HF repo; this flag takes a local checkpoint directory.
    help="Path to a local directory containing a pretrained policy checkpoint (e.g. `outputs/train/.../checkpoints/last`).",
)
parser.add_argument(
    "--batch-size",
    type=int,
    default=32,
    help="Batch size loaded by DataLoader.",
)
parser.add_argument(
    "--num-workers",
    type=int,
    default=4,
    help="Number of processes of Dataloader for loading the data.",
)
args = parser.parse_args() args = parser.parse_args()
visualize_dataset(**vars(args)) visualize_dataset(**vars(args))