forked from tangger/lerobot
Compare commits
4 Commits
test/robot...user/rcade
| Author | SHA1 | Date |
|---|---|---|
| | 5f32d75b58 | |
| | a66a792029 | |
| | b6aedcd9a5 | |
| | 121030cca7 | |
@@ -155,25 +155,34 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
+        l1_loss *= ~batch["action_is_pad"].unsqueeze(-1)

-        loss_dict = {"l1_loss": l1_loss.item()}
+        bsize, seqlen, num_motors = l1_loss.shape
+        output_dict = {
+            "l1_loss": l1_loss.mean().item(),
+            "l1_loss_per_item": l1_loss.view(bsize, seqlen * num_motors).mean(dim=1),
+            "action": self.unnormalize_outputs({"action": actions_hat})["action"],
+        }
         if self.config.use_vae:
             # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
             # each dimension independently, we sum over the latent dimension to get the total
             # KL-divergence per batch element, then take the mean over the batch.
             # (See App. B of https://arxiv.org/abs/1312.6114 for more details).
-            mean_kld = (
-                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
-            )
-            loss_dict["kld_loss"] = mean_kld.item()
+            mean_kld = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
+            output_dict["kld_loss_per_item"] = mean_kld
+
+            mean_kld = mean_kld.mean()
+            output_dict["kld_loss"] = mean_kld.item()
+
             loss = l1_loss + mean_kld * self.config.kl_weight
+            output_dict["loss_per_item"] = (
+                output_dict["l1_loss_per_item"] + output_dict["kld_loss_per_item"] * self.config.kl_weight
+            )
         else:
             loss = l1_loss

-        return loss, loss_dict
+        return loss, output_dict


 class ACTTemporalEnsembler:
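The change above keeps the masked L1 loss as a full (batch, sequence, motors) tensor so that a per-frame value can be reported alongside the scalar that is logged. Below is a minimal, self-contained sketch of that bookkeeping; the tensor names and shapes are illustrative, not taken from the repository.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 2 frames in the batch, action chunk of 4 steps, 6 motors.
actions = torch.randn(2, 4, 6)
actions_hat = torch.randn(2, 4, 6)
action_is_pad = torch.zeros(2, 4, dtype=torch.bool)  # no padded steps in this toy example

l1 = F.l1_loss(actions, actions_hat, reduction="none")  # shape (2, 4, 6)
l1 = l1 * ~action_is_pad.unsqueeze(-1)                  # zero out padded timesteps
bsize, seqlen, num_motors = l1.shape

l1_per_item = l1.view(bsize, seqlen * num_motors).mean(dim=1)  # one value per batch element
l1_scalar = l1.mean()                                          # scalar logged as "l1_loss"

# With equally sized items, the mean of the per-item losses equals the global mean.
assert torch.allclose(l1_scalar, l1_per_item.mean())
```

These per-item tensors, together with the unnormalized predicted actions, are what the new `save_inference.py` script below collects from `output_dict` and writes to disk per episode.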
lerobot/scripts/save_inference.py (new file, 162 lines)
@@ -0,0 +1,162 @@
import logging
import shutil
import tempfile
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat

import torch
import tqdm

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import (
    auto_select_torch_device,
    init_logging,
    is_amp_available,
    is_torch_device_available,
)
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig


@dataclass
class SaveInferenceConfig:
    dataset: DatasetConfig
    # Delete the output directory if it exists already.
    force_override: bool = False

    batch_size: int = 16
    num_workers: int = 4

    policy: PreTrainedConfig | None = None
    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool | None = None

    output_dir: str | Path | None = None

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

        # When no device or use_amp are given, use the one from training config.
        if self.device is None or self.use_amp is None:
            train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
            if self.device is None:
                self.device = train_cfg.device
            if self.use_amp is None:
                self.use_amp = train_cfg.use_amp

        # Automatically switch to available device if necessary
        if not is_torch_device_available(self.device):
            auto_device = auto_select_torch_device()
            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
            self.device = auto_device

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False


@parser.wrap()
def save_inference(cfg: SaveInferenceConfig):
    init_logging()
    logging.info(pformat(asdict(cfg)))

    dataset = make_dataset(cfg)
    policy = make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)

    output_dir = cfg.output_dir
    if output_dir is None:
        # Create a temporary directory that will be automatically cleaned up
        output_dir = tempfile.mkdtemp(prefix="lerobot_save_inference_")

    elif Path(output_dir).exists() and cfg.force_override:
        shutil.rmtree(cfg.output_dir)

    output_dir = Path(output_dir)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        pin_memory=cfg.device != "cpu",
        drop_last=False,
    )

    policy.train()

    episode_indices = []
    frame_indices = []
    feats = {}

    for batch in tqdm.tqdm(dataloader):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(cfg.device, non_blocking=True)

        with torch.no_grad(), torch.autocast(device_type=cfg.device) if cfg.use_amp else nullcontext():
            _, output_dict = policy.forward(batch)

        batch_size = batch["episode_index"].shape[0]
        episode_indices.append(batch["episode_index"])
        frame_indices.append(batch["frame_index"])

        for key, value in output_dict.items():
            if not isinstance(value, torch.Tensor) or value.shape[0] != batch_size:
                print(f"Skipping {key}")
                continue

            if key not in feats:
                feats[key] = []

            feats[key].append(value)

    episode_indices = torch.cat(episode_indices).cpu()
    frame_indices = torch.cat(frame_indices).cpu()

    # TODO(rcadene): use collate?
    for key, value in feats.items():
        if isinstance(value[0], (float, int)):
            feats[key] = torch.tensor(value)
        elif isinstance(value[0], torch.Tensor):
            feats[key] = torch.cat(value, dim=0).cpu()
        elif isinstance(value[0], str):
            pass
        else:
            raise NotImplementedError(f"{key}: {value}")

    # Find unique episode indices
    unique_episodes = torch.unique(episode_indices)

    for episode in unique_episodes:
        ep_feats = {}
        for key in feats:
            ep_feats[key] = feats[key][episode_indices == episode]

        output_dir.mkdir(parents=True, exist_ok=True)
        torch.save(ep_feats, output_dir / f"output_features_episode_{episode}.pth")

    print(f"Features can be found in: {output_dir}")


if __name__ == "__main__":
    save_inference()
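A minimal sketch of how one of the saved files could be read back; the output path is illustrative (by default the script writes to a temporary directory and prints its location), and the available keys depend on the policy. With the ACT `forward` change above, the tensors that survive the batch-dimension filter would include `l1_loss_per_item` and `action`:

```python
import torch

# Hypothetical path: the script writes one file per episode, named
# output_features_episode_<episode>.pth, into the chosen output_dir.
feats = torch.load("outputs/inference/output_features_episode_0.pth")

print(sorted(feats.keys()))             # e.g. ['action', 'l1_loss_per_item', ...]
print(feats["l1_loss_per_item"].shape)  # (num_frames,): one loss value per frame
print(feats["action"].shape)            # (num_frames, chunk_size, num_motors)
```

The `--inference-dir` flag added to the visualization script below is presumably meant to point at this output directory, so that the per-frame losses and the first predicted action of each chunk can be plotted next to the dataset's own columns.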
@@ -65,6 +65,7 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import requests
+import torch
 from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
@@ -80,6 +81,7 @@ def run_server(
     port: str,
     static_folder: Path,
     template_folder: Path,
+    inference_dir: Path | None,
 ):
     app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
@@ -139,7 +141,14 @@ def run_server(
     )

     @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+    def show_episode(
+        dataset_namespace,
+        dataset_name,
+        episode_id,
+        dataset=dataset,
+        episodes=episodes,
+        inference_dir=inference_dir,
+    ):
         repo_id = f"{dataset_namespace}/{dataset_name}"
         try:
             if dataset is None:
@@ -158,7 +167,7 @@ def run_server(
             if major_version < 2:
                 return "Make sure to convert your LeRobotDataset to v2 & above."

-        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id, inference_dir)
         dataset_info = {
             "repo_id": f"{dataset_namespace}/{dataset_name}",
             "num_samples": dataset.num_frames
@@ -228,7 +237,9 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+def get_episode_data(
+    dataset: LeRobotDataset | IterableNamespace, episode_index, inference_dir: Path | None = None
+):
     """Get a csv str containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
     columns = []
@@ -279,7 +290,33 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
             np.expand_dims(data["timestamp"], axis=1),
             *[np.vstack(data[col]) for col in selected_columns[1:]],
         )
-    ).tolist()
+    )

+    if inference_dir is not None:
+        feats = torch.load(inference_dir / f"output_features_episode_{episode_index}.pth")
+        for key in feats:
+            if "loss_per_item" in key:
+                if feats[key].ndim != 1:
+                    raise ValueError()
+
+                header.append(key.replace("loss_per_item", "loss"))
+                rows = np.concatenate([rows, feats[key][:, None]], axis=1)
+
+            elif key == "action":
+                if feats[key].ndim != 3:
+                    raise ValueError()
+
+                next_action = feats[key][:, 0, :]
+                num_motors = next_action.shape[1]
+
+                for i in range(num_motors):
+                    header.append(f"action_{i}")
+
+                rows = np.concatenate([rows, next_action], axis=1)
+            else:
+                raise NotImplementedError(key)
+
+    rows = rows.tolist()
+
     # Convert data to CSV string
     csv_buffer = StringIO()
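For reference, the concatenation pattern above turns every per-frame feature into one extra column of the CSV that Dygraphs plots. A small standalone sketch, with illustrative column names and values:

```python
import numpy as np

# Illustrative starting point: 3 frames with a timestamp and one state column.
header = ["timestamp", "state_0"]
rows = np.array([
    [0.0, 1.00],
    [0.1, 1.05],
    [0.2, 1.10],
])

# A per-frame scalar feature (e.g. an "l1_loss_per_item" tensor) becomes one extra column.
l1_loss_per_item = np.array([0.05, 0.04, 0.06])
header.append("l1_loss")
rows = np.concatenate([rows, l1_loss_per_item[:, None]], axis=1)

print(header)      # ['timestamp', 'state_0', 'l1_loss']
print(rows.shape)  # (3, 3): each appended feature adds one column per frame
```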
@@ -332,6 +369,7 @@ def visualize_dataset_html(
     host: str = "127.0.0.1",
     port: int = 9090,
     force_override: bool = False,
+    inference_dir: Path | None = None,
 ) -> Path | None:
     init_logging()

@@ -372,7 +410,7 @@
         ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

     if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        run_server(dataset, episodes, host, port, static_dir, template_dir, inference_dir)


 def main():
@@ -439,6 +477,12 @@ def main():
         default=0,
         help="Delete the output directory if it exists already.",
     )
+    parser.add_argument(
+        "--inference-dir",
+        type=Path,
+        default=None,
+        help="",
+    )

     args = parser.parse_args()
     kwargs = vars(args)