forked from tangger/lerobot
Add offline inference to visualize_dataset
This commit is contained in:
@@ -60,13 +60,31 @@ import shutil
|
|||||||
import socketserver
|
import socketserver
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
import yaml
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.common.utils.utils import init_logging
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeSampler(torch.utils.data.Sampler):
|
||||||
|
def __init__(self, dataset, episode_index):
|
||||||
|
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||||
|
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||||
|
self.frame_ids = range(from_idx, to_idx)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.frame_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.frame_ids)
|
||||||
|
|
||||||
|
|
||||||
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
|
class NoCacheHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
def end_headers(self):
|
def end_headers(self):
|
||||||
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
|
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||||
@@ -114,10 +132,10 @@ def create_html_page(page_title: str):
|
|||||||
|
|
||||||
main_div = soup.new_tag("div")
|
main_div = soup.new_tag("div")
|
||||||
body.append(main_div)
|
body.append(main_div)
|
||||||
return soup, body
|
return soup, head, body
|
||||||
|
|
||||||
|
|
||||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
|
||||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||||
@@ -125,6 +143,7 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
|
|
||||||
has_state = "observation.state" in dataset.hf_dataset.features
|
has_state = "observation.state" in dataset.hf_dataset.features
|
||||||
has_action = "action" in dataset.hf_dataset.features
|
has_action = "action" in dataset.hf_dataset.features
|
||||||
|
has_inference = inference_results is not None
|
||||||
|
|
||||||
# init header of csv with state and action names
|
# init header of csv with state and action names
|
||||||
header = ["timestamp"]
|
header = ["timestamp"]
|
||||||
@@ -134,6 +153,12 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
if has_action:
|
if has_action:
|
||||||
dim_action = len(dataset.hf_dataset["action"][0])
|
dim_action = len(dataset.hf_dataset["action"][0])
|
||||||
header += [f"action_{i}" for i in range(dim_action)]
|
header += [f"action_{i}" for i in range(dim_action)]
|
||||||
|
if has_inference:
|
||||||
|
assert "actions" in inference_results
|
||||||
|
assert "loss" in inference_results
|
||||||
|
dim_pred_action = inference_results["actions"].shape[2]
|
||||||
|
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
|
||||||
|
header += ["loss"]
|
||||||
|
|
||||||
columns = ["timestamp"]
|
columns = ["timestamp"]
|
||||||
if has_state:
|
if has_state:
|
||||||
@@ -151,6 +176,14 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||||||
row += data[i]["action"].tolist()
|
row += data[i]["action"].tolist()
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
|
|
||||||
|
if has_inference:
|
||||||
|
num_frames = len(rows)
|
||||||
|
assert num_frames == inference_results["actions"].shape[0]
|
||||||
|
assert num_frames == inference_results["loss"].shape[0]
|
||||||
|
for i in range(num_frames):
|
||||||
|
rows[i] += inference_results["actions"][i, 0].tolist()
|
||||||
|
rows[i] += [inference_results["loss"][i].item()]
|
||||||
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
with open(output_dir / file_name, "w") as f:
|
with open(output_dir / file_name, "w") as f:
|
||||||
f.write(",".join(header) + "\n")
|
f.write(",".join(header) + "\n")
|
||||||
@@ -172,7 +205,13 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
|
|||||||
s += " legend: 'always',\n"
|
s += " legend: 'always',\n"
|
||||||
s += " labelsDiv: document.getElementById('labels'),\n"
|
s += " labelsDiv: document.getElementById('labels'),\n"
|
||||||
s += " labelsSeparateLines: true,\n"
|
s += " labelsSeparateLines: true,\n"
|
||||||
s += " labelsKMB: true\n"
|
s += " labelsKMB: true,\n"
|
||||||
|
s += " highlightCircleSize: 1.5,\n"
|
||||||
|
s += " highlightSeriesOpts: {\n"
|
||||||
|
s += " strokeWidth: 1.5,\n"
|
||||||
|
s += " strokeBorderWidth: 1,\n"
|
||||||
|
s += " highlightCircleSize: 3\n"
|
||||||
|
s += " }\n"
|
||||||
s += " });\n"
|
s += " });\n"
|
||||||
s += "\n"
|
s += "\n"
|
||||||
s += " // Function to play both videos\n"
|
s += " // Function to play both videos\n"
|
||||||
@@ -215,7 +254,14 @@ def write_episode_data_js(output_dir, file_name, ep_csv_fname, dataset):
|
|||||||
|
|
||||||
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
|
def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
|
||||||
"""Write an html file containg video feeds and timeseries associated to an episode."""
|
"""Write an html file containg video feeds and timeseries associated to an episode."""
|
||||||
soup, body = create_html_page("")
|
soup, head, body = create_html_page("")
|
||||||
|
|
||||||
|
css_style = soup.new_tag("style")
|
||||||
|
css_style.string = ""
|
||||||
|
css_style.string += "#labels > span.highlight {\n"
|
||||||
|
css_style.string += " border: 1px solid grey;\n"
|
||||||
|
css_style.string += "}"
|
||||||
|
head.append(css_style)
|
||||||
|
|
||||||
# Add videos from camera feeds
|
# Add videos from camera feeds
|
||||||
|
|
||||||
@@ -295,7 +341,7 @@ def write_episode_data_html(output_dir, file_name, js_fname, ep_index, dataset):
|
|||||||
def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
|
def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames, dataset):
|
||||||
"""Write an html file containing information related to the dataset and a list of links to
|
"""Write an html file containing information related to the dataset and a list of links to
|
||||||
html pages of episodes."""
|
html pages of episodes."""
|
||||||
soup, body = create_html_page("TODO")
|
soup, head, body = create_html_page("TODO")
|
||||||
|
|
||||||
h3 = soup.new_tag("h3")
|
h3 = soup.new_tag("h3")
|
||||||
h3.string = "TODO"
|
h3.string = "TODO"
|
||||||
@@ -337,17 +383,68 @@ def write_episodes_list_html(output_dir, file_name, ep_indices, ep_html_fnames,
|
|||||||
f.write(soup.prettify())
|
f.write(soup.prettify())
|
||||||
|
|
||||||
|
|
||||||
|
def run_inference(dataset, episode_index, policy, num_workers=4, batch_size=32, device="cuda"):
|
||||||
|
policy.eval()
|
||||||
|
policy.to(device)
|
||||||
|
|
||||||
|
logging.info("Loading dataloader")
|
||||||
|
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=num_workers,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sampler=episode_sampler,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Running inference")
|
||||||
|
inference_results = {}
|
||||||
|
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||||
|
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||||
|
with torch.inference_mode():
|
||||||
|
output_dict = policy.forward(batch)
|
||||||
|
|
||||||
|
for key in output_dict:
|
||||||
|
if key not in inference_results:
|
||||||
|
inference_results[key] = []
|
||||||
|
inference_results[key].append(output_dict[key].to("cpu"))
|
||||||
|
|
||||||
|
for key in inference_results:
|
||||||
|
inference_results[key] = torch.cat(inference_results[key])
|
||||||
|
|
||||||
|
return inference_results
|
||||||
|
|
||||||
|
|
||||||
def visualize_dataset(
|
def visualize_dataset(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
episode_indices: list[int] = None,
|
episode_indices: list[int] = None,
|
||||||
output_dir: Path | None = None,
|
output_dir: Path | None = None,
|
||||||
serve: bool = True,
|
serve: bool = True,
|
||||||
port: int = 9090,
|
port: int = 9090,
|
||||||
|
force_overwrite: bool = True,
|
||||||
|
policy_repo_id: str | None = None,
|
||||||
|
policy_ckpt_path: Path | None = None,
|
||||||
|
batch_size: int = 32,
|
||||||
|
num_workers: int = 4,
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
has_policy = policy_repo_id or policy_ckpt_path
|
||||||
|
|
||||||
|
if has_policy:
|
||||||
|
logging.info("Loading policy")
|
||||||
|
if policy_repo_id:
|
||||||
|
pretrained_policy_path = Path(snapshot_download(policy_repo_id))
|
||||||
|
elif policy_ckpt_path:
|
||||||
|
pretrained_policy_path = Path(policy_ckpt_path)
|
||||||
|
policy = ACTPolicy.from_pretrained(pretrained_policy_path)
|
||||||
|
with open(pretrained_policy_path / "config.yaml") as f:
|
||||||
|
cfg = yaml.safe_load(f)
|
||||||
|
delta_timestamps = cfg["training"]["delta_timestamps"]
|
||||||
|
else:
|
||||||
|
delta_timestamps = None
|
||||||
|
|
||||||
logging.info("Loading dataset")
|
logging.info("Loading dataset")
|
||||||
dataset = LeRobotDataset(repo_id)
|
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||||
|
|
||||||
if not dataset.video:
|
if not dataset.video:
|
||||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||||
@@ -356,31 +453,41 @@ def visualize_dataset(
|
|||||||
output_dir = f"outputs/visualize_dataset/{repo_id}"
|
output_dir = f"outputs/visualize_dataset/{repo_id}"
|
||||||
|
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
if output_dir.exists():
|
if force_overwrite and output_dir.exists():
|
||||||
shutil.rmtree(output_dir)
|
shutil.rmtree(output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||||
# so that the http server can get access to the mp4 files.
|
# so that the http server can get access to the mp4 files.
|
||||||
ln_videos_dir = output_dir / "videos"
|
ln_videos_dir = output_dir / "videos"
|
||||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
if not ln_videos_dir.exists():
|
||||||
|
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||||
|
|
||||||
if episode_indices is None:
|
if episode_indices is None:
|
||||||
episode_indices = list(range(dataset.num_episodes))
|
episode_indices = list(range(dataset.num_episodes))
|
||||||
|
|
||||||
logging.info("Writing html")
|
logging.info("Writing html")
|
||||||
ep_html_fnames = []
|
ep_html_fnames = []
|
||||||
for episode_idx in tqdm.tqdm(episode_indices):
|
for episode_index in tqdm.tqdm(episode_indices):
|
||||||
# write states and actions in a csv
|
inference_results = None
|
||||||
ep_csv_fname = f"episode_{episode_idx}.csv"
|
if has_policy:
|
||||||
write_episode_data_csv(output_dir, ep_csv_fname, episode_idx, dataset)
|
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
|
||||||
|
if inference_results_path.exists():
|
||||||
|
inference_results = load_file(inference_results_path)
|
||||||
|
else:
|
||||||
|
inference_results = run_inference(dataset, episode_index, policy)
|
||||||
|
save_file(inference_results, inference_results_path)
|
||||||
|
|
||||||
js_fname = f"episode_{episode_idx}.js"
|
# write states and actions in a csv
|
||||||
|
ep_csv_fname = f"episode_{episode_index}.csv"
|
||||||
|
write_episode_data_csv(output_dir, ep_csv_fname, episode_index, dataset, inference_results)
|
||||||
|
|
||||||
|
js_fname = f"episode_{episode_index}.js"
|
||||||
write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
|
write_episode_data_js(output_dir, js_fname, ep_csv_fname, dataset)
|
||||||
|
|
||||||
# write a html page to view videos and timeseries
|
# write a html page to view videos and timeseries
|
||||||
ep_html_fname = f"episode_{episode_idx}.html"
|
ep_html_fname = f"episode_{episode_index}.html"
|
||||||
write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_idx, dataset)
|
write_episode_data_html(output_dir, ep_html_fname, js_fname, episode_index, dataset)
|
||||||
ep_html_fnames.append(ep_html_fname)
|
ep_html_fnames.append(ep_html_fname)
|
||||||
|
|
||||||
write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
|
write_episodes_list_html(output_dir, "index.html", episode_indices, ep_html_fnames, dataset)
|
||||||
@@ -396,7 +503,7 @@ def main():
|
|||||||
"--repo-id",
|
"--repo-id",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
|
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--episode-indices",
|
"--episode-indices",
|
||||||
@@ -423,6 +530,37 @@ def main():
|
|||||||
default=9090,
|
default=9090,
|
||||||
help="Web port used by the http server.",
|
help="Web port used by the http server.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force-overwrite",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Delete the output directory if it exists already.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy-repo-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy-ckpt-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Name of hugging face repositery containing a pretrained policy (e.g. `lerobot/diffusion_pusht` for https://huggingface.co/lerobot/diffusion_pusht).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Batch size loaded by DataLoader.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of processes of Dataloader for loading the data.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
visualize_dataset(**vars(args))
|
visualize_dataset(**vars(args))
|
||||||
|
|||||||
Reference in New Issue
Block a user