Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Author: Simon Alibert
Date: 2025-02-11 10:36:06 +01:00
Committed by: GitHub
Parent: 334deb985d
Commit: 90e099b39f
40 changed files with 1515 additions and 935 deletions


@@ -15,58 +15,61 @@
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pprint import pformat
from threading import Lock
from typing import Any
import numpy as np
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
-from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
+from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
-from lerobot.common.logger import Logger, log_output_dir
-from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
+from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
+from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
+from lerobot.common.utils.random_utils import set_seed
+from lerobot.common.utils.train_utils import (
+get_step_checkpoint_dir,
+get_step_identifier,
+load_training_state,
+save_checkpoint,
+update_last_checkpoint,
+)
from lerobot.common.utils.utils import (
format_big_number,
-get_safe_dtype,
get_safe_torch_device,
has_method,
init_logging,
-set_global_seed,
)
+from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
-policy,
-batch,
-optimizer,
-grad_clip_norm,
+train_metrics: MetricsTracker,
+policy: PreTrainedPolicy,
+batch: Any,
+optimizer: Optimizer,
+grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
-):
-"""Returns a dictionary of items for logging."""
+) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
-output_dict = policy.forward(batch)
+loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
-loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
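# For reference, a minimal sketch of the standard torch.amp sequence that the collapsed
# lines after this comment are expected to follow (an assumption for illustration, not
# necessarily the exact code of this commit):
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
    policy.parameters(), grad_clip_norm, error_if_nonfinite=False
)
with lock if lock is not None else nullcontext():
    grad_scaler.step(optimizer)  # the optimizer step is skipped if grads contain inf/NaN
grad_scaler.update()  # adjust the loss scale for the next iteration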
@@ -87,9 +90,6 @@ def update_policy(
optimizer.zero_grad()
if hasattr(policy, "update_ema_modules"):
policy.update_ema_modules()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
@@ -98,113 +98,34 @@ def update_policy(
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
-info = {
-"loss": loss.item(),
-"grad_norm": float(grad_norm),
-"lr": optimizer.param_groups[0]["lr"],
-"update_s": time.perf_counter() - start_time,
-**{k: v for k, v in output_dict.items() if k != "loss"},
-}
-info.update({k: v for k, v in output_dict.items() if k not in info})
-return info
def log_train_info(
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
update_s = info["update_s"]
dataloading_s = info["dataloading_s"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"loss:{loss:.3f}",
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="train")
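# Worked example of the bookkeeping above (illustrative numbers only): with batch_size=8
# and a dataset of 10_000 frames across 50 episodes, after step 999 we get
# num_samples = 1000 * 8 = 8000, avg_samples_per_ep = 10_000 / 50 = 200,
# num_episodes = 8000 / 200 = 40.0, and num_epochs = 8000 / 10_000 = 0.8.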
def log_eval_info(logger, info, step, cfg, dataset, is_online):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"∑rwrd:{avg_sum_reward:.3f}",
f"success:{pc_success:.1f}%",
f"eval_s:{eval_s:.3f}",
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="eval")
+train_metrics.loss = loss.item()
+train_metrics.grad_norm = grad_norm.item()
+train_metrics.lr = optimizer.param_groups[0]["lr"]
+train_metrics.update_s = time.perf_counter() - start_time
+return train_metrics, output_dict
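# The attribute-style updates above assume a metrics interface roughly like the sketch
# below (a hypothetical minimal version, not the actual lerobot.common.utils.logging_utils
# implementation): assigning to a named metric feeds a running average, step() advances
# the step count, to_dict() exports scalars for wandb, and reset_averages() clears the
# meters after each log.
class _AverageMeterSketch:
    def __init__(self, name: str, fmt: str = ":.3f"):
        self.name, self.fmt = name, fmt
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val: float, n: int = 1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0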
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
-logging.info(pformat(cfg.to_dict()))
+logging.info(pformat(asdict(cfg)))
# log metrics to terminal and wandb
-logger = Logger(cfg)
+if cfg.wandb.enable and cfg.wandb.project:
+wandb_logger = WandBLogger(cfg)
+else:
+wandb_logger = None
+logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
-set_global_seed(cfg.seed)
+set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
-offline_dataset = make_dataset(cfg)
+dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -218,7 +139,7 @@ def train(cfg: TrainPipelineConfig):
policy = make_policy(
cfg=cfg.policy,
device=device,
-ds_meta=offline_dataset.meta,
+ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
@@ -233,65 +154,29 @@ def train(cfg: TrainPipelineConfig):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
-log_output_dir(cfg.output_dir)
+logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
logging.info(f"{cfg.online.steps=}")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
step_identifier = f"{step:0{_num_digits}d}"
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_checkpoint and (
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
):
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpoint(
step,
step_identifier,
policy,
optimizer,
lr_scheduler,
)
logging.info("Resume training")
# create dataloader for offline training
-if getattr(cfg.policy, "drop_n_last_frames", None):
+if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
-offline_dataset.episode_data_index,
+dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
-offline_dataset,
+dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
@@ -303,23 +188,30 @@ def train(cfg: TrainPipelineConfig):
policy.train()
if hasattr(policy, "init_ema_modules"):
policy.init_ema_modules()
+train_metrics = {
+"loss": AverageMeter("loss", ":.3f"),
+"grad_norm": AverageMeter("grdn", ":.3f"),
+"lr": AverageMeter("lr", ":0.1e"),
+"update_s": AverageMeter("updt_s", ":.3f"),
+"dataloading_s": AverageMeter("data_s", ":.3f"),
+}
-offline_step = 0
-for _ in range(step, cfg.offline.steps):
-if offline_step == 0:
-logging.info("Start offline training on a fixed dataset")
+train_tracker = MetricsTracker(
+cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+)
+logging.info("Start offline training on a fixed dataset")
+for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
-dataloading_s = time.perf_counter() - start_time
+train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
-train_info = update_policy(
+train_tracker, output_dict = update_policy(
+train_tracker,
policy,
batch,
optimizer,
@@ -329,231 +221,57 @@ def train(cfg: TrainPipelineConfig):
use_amp=cfg.use_amp,
)
train_info["dataloading_s"] = dataloading_s
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
offline_step += 1 # noqa: SIM113
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
-if cfg.online.steps == 0:
-if eval_env:
-eval_env.close()
-logging.info("End of training")
-return
+if is_log_step:
+logging.info(train_tracker)
+if wandb_logger:
+wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
+wandb_logger.log_dict(wandb_log_dict)
+train_tracker.reset_averages()
-# Online training.
+if cfg.save_checkpoint and is_saving_step:
+logging.info(f"Checkpoint policy after step {step}")
+checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
+save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+update_last_checkpoint(checkpoint_dir)
+if wandb_logger:
+wandb_logger.log_policy(checkpoint_dir)
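# Rough sketch of the on-disk layout the checkpointing helpers above are expected to
# produce (directory names are assumptions for illustration; see
# lerobot.common.utils.train_utils for the authoritative layout):
#
#     <output_dir>/checkpoints/
#     ├── 000500/                  # get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
#     │   ├── pretrained_model/    # policy weights and config
#     │   └── training_state/      # optimizer, scheduler and step state for resuming
#     └── last -> 000500           # update_last_checkpoint() repoints this to the newest step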
# Create an env dedicated to online episodes collection from policy rollout.
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
online_buffer_path = logger.log_dir / "online_buffer"
if cfg.resume and not online_buffer_path.exists():
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
# buffer.
logging.warning(
"When online training is resumed, we load the latest online buffer from the prior run, "
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
"was made. This is because the online buffer is updated on disk during training, independently "
"of our explicit checkpointing mechanisms."
)
online_dataset = OnlineBuffer(
online_buffer_path,
data_spec={
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.input_features.items()
},
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.output_features.items()
},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# FIXME: 'task' is a string
# "task": {"shape": (), "dtype": np.dtype("?")},
# FIXME: 'next.success' is expected by pusht env but not xarm
"next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.online.buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
delta_timestamps=delta_timestamps,
)
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
# makes it possible to do online rollouts in parallel with training updates).
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
# Create dataloader for online training.
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
sampler_weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler = torch.utils.data.WeightedRandomSampler(
sampler_weights,
num_samples=len(concat_dataset),
replacement=True,
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
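# The `dl_iter = cycle(dataloader)` pattern (the same one feeds the training loop above)
# assumes a helper that yields batches forever by restarting the dataloader when it is
# exhausted. A minimal sketch, not the real lerobot.common.datasets.utils implementation:
def _cycle_sketch(iterable):
    while True:
        for item in iterable:
            yield item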
if cfg.online.do_rollout_async:
# Lock and thread pool executor for asynchronous online rollouts.
lock = Lock()
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# parallelization of rollouts is handled within the job.
executor = ThreadPoolExecutor(max_workers=1)
else:
lock = None
online_step = 0
online_rollout_s = 0 # time take to do online rollout
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
# online rollout option.
await_update_online_buffer_s = 0
rollout_start_seed = cfg.online.env_seed
while True:
if online_step == cfg.online.steps:
break
if online_step == 0:
logging.info("Start online training by interacting with environment")
-def sample_trajectory_and_update_buffer():
-nonlocal rollout_start_seed
-with lock if lock is not None else nullcontext():
-online_rollout_policy.load_state_dict(policy.state_dict())
-online_rollout_policy.eval()
-start_rollout_time = time.perf_counter()
-with torch.no_grad():
+if cfg.env and is_eval_step:
+step_id = get_step_identifier(step, cfg.steps)
+logging.info(f"Eval policy at step {step}")
+with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
-online_env,
-online_rollout_policy,
-n_episodes=cfg.online.rollout_n_episodes,
-max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
-videos_dir=logger.log_dir / "online_rollout_videos",
-return_episode_data=True,
-start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
+eval_env,
+policy,
+cfg.eval.n_episodes,
+videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
+max_episodes_rendered=4,
+start_seed=cfg.seed,
)
online_rollout_s = time.perf_counter() - start_rollout_time
if len(offline_dataset.meta.tasks) > 1:
raise NotImplementedError("Add support for multi task.")
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
total_num_frames = eval_info["episodes"]["index"].shape[0]
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])
# Update the concatenated dataset length used during sampling.
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler.num_frames = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
return online_rollout_s, update_online_buffer_s
if lock is None:
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
else:
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()
if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue
policy.train()
for _ in range(cfg.online.steps_between_rollouts):
with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
dtype = get_safe_dtype(batch[key].dtype, device)
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
-train_info = update_policy(
-policy,
-batch,
-optimizer,
-cfg.optimizer.grad_clip_norm,
-grad_scaler=grad_scaler,
-lr_scheduler=lr_scheduler,
-use_amp=cfg.use_amp,
-lock=lock,
+eval_metrics = {
+"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
+"pc_success": AverageMeter("success", ":.1f"),
+"eval_s": AverageMeter("eval_s", ":.3f"),
+}
+eval_tracker = MetricsTracker(
+cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+)
-train_info["dataloading_s"] = dataloading_s
-train_info["online_rollout_s"] = online_rollout_s
-train_info["update_online_buffer_s"] = update_online_buffer_s
-train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-with lock if lock is not None else nullcontext():
-train_info["online_buffer_size"] = len(online_dataset)
-if step % cfg.log_freq == 0:
-log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
-# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
-# so we pass in step + 1.
-evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
-step += 1
-online_step += 1
-# If we're doing async rollouts, we should now wait until we've completed them before proceeding
-# to do the next batch of rollouts.
-if cfg.online.do_rollout_async and future.running():
-start = time.perf_counter()
-online_rollout_s, update_online_buffer_s = future.result()
-await_update_online_buffer_s = time.perf_counter() - start
-if online_step >= cfg.online.steps:
-break
+eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
+eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
+eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
+logging.info(eval_tracker)
+if wandb_logger:
+wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()