Add save inference
This commit is contained in:
@@ -155,21 +155,29 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
bsize, seqlen, num_motors = l1_loss.shape
|
||||||
|
loss_dict = {
|
||||||
|
"l1_loss": l1_loss.mean().item(),
|
||||||
|
"l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
|
||||||
|
}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||||
mean_kld = (
|
mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
loss_dict["kld_loss_per_item"] = mean_kld
|
||||||
)
|
|
||||||
|
mean_kld = mean_kld.mean()
|
||||||
loss_dict["kld_loss"] = mean_kld.item()
|
loss_dict["kld_loss"] = mean_kld.item()
|
||||||
|
|
||||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||||
|
loss_dict["loss_per_item"] = (
|
||||||
|
loss_dict["l1_loss_per_item"] + loss_dict["kld_loss_per_item"] * self.config.kl_weight
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loss = l1_loss
|
loss = l1_loss
|
||||||
|
|
||||||
|
|||||||
159
lerobot/scripts/save_inference.py
Normal file
159
lerobot/scripts/save_inference.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.utils.utils import (
|
||||||
|
auto_select_torch_device,
|
||||||
|
init_logging,
|
||||||
|
is_amp_available,
|
||||||
|
is_torch_device_available,
|
||||||
|
)
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.default import DatasetConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SaveInferenceConfig:
|
||||||
|
dataset: DatasetConfig
|
||||||
|
# Delete the output directory if it exists already.
|
||||||
|
force_override: bool = False
|
||||||
|
|
||||||
|
batch_size: int = 16
|
||||||
|
num_workers: int = 4
|
||||||
|
|
||||||
|
policy: PreTrainedConfig | None = None
|
||||||
|
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
|
||||||
|
device: str | None = None # cuda | cpu | mps
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: bool | None = None
|
||||||
|
|
||||||
|
output_dir: str | Path | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||||
|
return ["policy"]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
|
policy_path = parser.get_path_arg("policy")
|
||||||
|
if policy_path:
|
||||||
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
|
# When no device or use_amp are given, use the one from training config.
|
||||||
|
if self.device is None or self.use_amp is None:
|
||||||
|
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
|
||||||
|
if self.device is None:
|
||||||
|
self.device = train_cfg.device
|
||||||
|
if self.use_amp is None:
|
||||||
|
self.use_amp = train_cfg.use_amp
|
||||||
|
|
||||||
|
# Automatically switch to available device if necessary
|
||||||
|
if not is_torch_device_available(self.device):
|
||||||
|
auto_device = auto_select_torch_device()
|
||||||
|
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||||
|
self.device = auto_device
|
||||||
|
|
||||||
|
# Automatically deactivate AMP if necessary
|
||||||
|
if self.use_amp and not is_amp_available(self.device):
|
||||||
|
logging.warning(
|
||||||
|
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||||
|
)
|
||||||
|
self.use_amp = False
|
||||||
|
|
||||||
|
|
||||||
|
@parser.wrap()
|
||||||
|
def save_inference(cfg: SaveInferenceConfig):
|
||||||
|
init_logging()
|
||||||
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
|
dataset = make_dataset(cfg)
|
||||||
|
policy = make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||||
|
|
||||||
|
output_dir = cfg.output_dir
|
||||||
|
if output_dir is None:
|
||||||
|
# Create a temporary directory that will be automatically cleaned up
|
||||||
|
output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")
|
||||||
|
|
||||||
|
elif Path(output_dir).exists():
|
||||||
|
if cfg.force_override:
|
||||||
|
shutil.rmtree(cfg.output_dir)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Output directory already exists: {cfg.output_dir}")
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
pin_memory=cfg.device != "cpu",
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
episode_indices = []
|
||||||
|
frame_indices = []
|
||||||
|
feats = {}
|
||||||
|
|
||||||
|
for batch in tqdm.tqdm(dataloader):
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
|
||||||
|
_, output_dict = policy.forward(batch)
|
||||||
|
|
||||||
|
bsize = batch["episode_index"].shape[0]
|
||||||
|
episode_indices.append(batch["episode_index"])
|
||||||
|
frame_indices.append(batch["frame_index"])
|
||||||
|
|
||||||
|
for key in output_dict:
|
||||||
|
if "loss_per_item" not in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key not in feats:
|
||||||
|
feats[key] = []
|
||||||
|
|
||||||
|
if not (output_dict[key].ndim == 1 and output_dict[key].shape[0] == bsize):
|
||||||
|
raise ValueError(output_dict[key].shape)
|
||||||
|
|
||||||
|
feats[key].append(output_dict[key])
|
||||||
|
|
||||||
|
episode_indices = torch.cat(episode_indices)
|
||||||
|
frame_indices = torch.cat(frame_indices)
|
||||||
|
|
||||||
|
for key in feats:
|
||||||
|
feats[key] = torch.cat(feats[key])
|
||||||
|
|
||||||
|
# Find unique episode indices
|
||||||
|
unique_episodes = torch.unique(episode_indices)
|
||||||
|
|
||||||
|
for episode in unique_episodes:
|
||||||
|
ep_feats = {}
|
||||||
|
for key in feats:
|
||||||
|
ep_feats[key] = feats[key][episode_indices == episode].data.cpu()
|
||||||
|
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")
|
||||||
|
|
||||||
|
print(f"Features can be found in: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
save_inference()
|
||||||
@@ -65,6 +65,7 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import requests
|
import requests
|
||||||
|
import torch
|
||||||
from flask import Flask, redirect, render_template, request, url_for
|
from flask import Flask, redirect, render_template, request, url_for
|
||||||
|
|
||||||
from lerobot import available_datasets
|
from lerobot import available_datasets
|
||||||
@@ -80,6 +81,7 @@ def run_server(
|
|||||||
port: str,
|
port: str,
|
||||||
static_folder: Path,
|
static_folder: Path,
|
||||||
template_folder: Path,
|
template_folder: Path,
|
||||||
|
inference_dir: Path | None,
|
||||||
):
|
):
|
||||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||||
@@ -139,7 +141,14 @@ def run_server(
|
|||||||
)
|
)
|
||||||
|
|
||||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||||
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
|
def show_episode(
|
||||||
|
dataset_namespace,
|
||||||
|
dataset_name,
|
||||||
|
episode_id,
|
||||||
|
dataset=dataset,
|
||||||
|
episodes=episodes,
|
||||||
|
inference_dir=inference_dir,
|
||||||
|
):
|
||||||
repo_id = f"{dataset_namespace}/{dataset_name}"
|
repo_id = f"{dataset_namespace}/{dataset_name}"
|
||||||
try:
|
try:
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@@ -158,7 +167,7 @@ def run_server(
|
|||||||
if major_version < 2:
|
if major_version < 2:
|
||||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||||
|
|
||||||
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
|
episode_data_csv_str, columns = get_episode_data(dataset, episode_id, inference_dir)
|
||||||
dataset_info = {
|
dataset_info = {
|
||||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||||
"num_samples": dataset.num_frames
|
"num_samples": dataset.num_frames
|
||||||
@@ -228,7 +237,9 @@ def get_ep_csv_fname(episode_id: int):
|
|||||||
return ep_csv_fname
|
return ep_csv_fname
|
||||||
|
|
||||||
|
|
||||||
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
|
def get_episode_data(
|
||||||
|
dataset: LeRobotDataset | IterableNamespace, episode_index, inference_dir: Path | None = None
|
||||||
|
):
|
||||||
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
|
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
|
||||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||||
columns = []
|
columns = []
|
||||||
@@ -279,7 +290,15 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
|||||||
np.expand_dims(data["timestamp"], axis=1),
|
np.expand_dims(data["timestamp"], axis=1),
|
||||||
*[np.vstack(data[col]) for col in selected_columns[1:]],
|
*[np.vstack(data[col]) for col in selected_columns[1:]],
|
||||||
)
|
)
|
||||||
).tolist()
|
)
|
||||||
|
|
||||||
|
if inference_dir is not None:
|
||||||
|
feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
|
||||||
|
for key in feats:
|
||||||
|
header.append(key.replace("loss_per_item", "loss"))
|
||||||
|
rows = np.concatenate([rows, feats[key][:, None]], axis=1)
|
||||||
|
|
||||||
|
rows = rows.tolist()
|
||||||
|
|
||||||
# Convert data to CSV string
|
# Convert data to CSV string
|
||||||
csv_buffer = StringIO()
|
csv_buffer = StringIO()
|
||||||
@@ -332,6 +351,7 @@ def visualize_dataset_html(
|
|||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 9090,
|
port: int = 9090,
|
||||||
force_override: bool = False,
|
force_override: bool = False,
|
||||||
|
inference_dir: Path | None = None,
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
@@ -372,7 +392,7 @@ def visualize_dataset_html(
|
|||||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||||
|
|
||||||
if serve:
|
if serve:
|
||||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
run_server(dataset, episodes, host, port, static_dir, template_dir, inference_dir)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -439,6 +459,12 @@ def main():
|
|||||||
default=0,
|
default=0,
|
||||||
help="Delete the output directory if it exists already.",
|
help="Delete the output directory if it exists already.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--inference-dir",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
kwargs = vars(args)
|
kwargs = vars(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user