Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
Committed by Simon Alibert on 2025-01-31 13:57:37 +01:00 (via GitHub)
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

lerobot/scripts/train.py

@@ -18,92 +18,36 @@ import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from copy import deepcopy
-from pathlib import Path
+from dataclasses import asdict
 from pprint import pformat
 from threading import Lock

-import hydra
 import numpy as np
 import torch
-from deepdiff import DeepDiff
-from omegaconf import DictConfig, ListConfig, OmegaConf
-from termcolor import colored
 from torch import nn
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler

 from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
-from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
+from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.policy_protocol import PolicyWithUpdate
 from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
-    init_hydra_config,
+    has_method,
     init_logging,
     set_global_seed,
 )
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.eval import eval_policy

-def make_optimizer_and_scheduler(cfg, policy):
-    if cfg.policy.name == "act":
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if not n.startswith("model.backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p
-                    for n, p in policy.named_parameters()
-                    if n.startswith("model.backbone") and p.requires_grad
-                ],
-                "lr": cfg.training.lr_backbone,
-            },
-        ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
-        )
-        lr_scheduler = None
-    elif cfg.policy.name == "diffusion":
-        optimizer = torch.optim.Adam(
-            policy.diffusion.parameters(),
-            cfg.training.lr,
-            cfg.training.adam_betas,
-            cfg.training.adam_eps,
-            cfg.training.adam_weight_decay,
-        )
-        from diffusers.optimization import get_scheduler
-
-        lr_scheduler = get_scheduler(
-            cfg.training.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.training.lr_warmup_steps,
-            num_training_steps=cfg.training.offline_steps,
-        )
-    elif policy.name == "tdmpc":
-        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
-        lr_scheduler = None
-    elif cfg.policy.name == "vqbet":
-        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
-
-        optimizer = VQBeTOptimizer(policy, cfg)
-        lr_scheduler = VQBeTScheduler(optimizer, cfg)
-    else:
-        raise NotImplementedError()
-    return optimizer, lr_scheduler

 def update_policy(
     policy,
     batch,
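The per-policy `if/elif` ladder deleted above moves into `lerobot.common.optim.factory.make_optimizer_and_scheduler`, which the new import block pulls in. A minimal sketch of the idea, assuming a hypothetical `AdamWConfig` preset (the real presets in `lerobot/common/optim` and their field names are not shown in this diff):

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class AdamWConfig:
    # Hypothetical optimizer preset; real presets live in lerobot/common/optim.
    lr: float = 1e-4
    weight_decay: float = 1e-2

    def build(self, params) -> torch.optim.Optimizer:
        return torch.optim.AdamW(params, lr=self.lr, weight_decay=self.weight_decay)


def make_optimizer_and_scheduler(optim_cfg: AdamWConfig, policy: nn.Module):
    # Each policy config carries its own preset, so the training script no
    # longer needs one branch per policy name.
    return optim_cfg.build(policy.parameters()), None


optimizer, lr_scheduler = make_optimizer_and_scheduler(AdamWConfig(), nn.Linear(4, 2))
```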
@@ -145,7 +89,7 @@ def update_policy(
     if lr_scheduler is not None:
         lr_scheduler.step()

-    if isinstance(policy, PolicyWithUpdate):
+    if has_method(policy, "update"):
         # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
         policy.update()
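`has_method` swaps the `PolicyWithUpdate` protocol for plain duck typing. A sketch under the assumption that the helper is a one-liner (the real one lives in `lerobot.common.utils.utils`); `EMAPolicy` is purely illustrative:

```python
def has_method(obj: object, name: str) -> bool:
    # True when `obj` exposes a callable attribute called `name`.
    return callable(getattr(obj, name, None))


class EMAPolicy:
    def update(self) -> None:
        print("EMA buffers updated")


policy = EMAPolicy()
if has_method(policy, "update"):
    policy.update()  # prints "EMA buffers updated"
```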
@@ -161,7 +105,9 @@ def update_policy(
     return info

-def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
+def log_train_info(
+    logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
+):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
@@ -170,7 +116,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation, action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
-    num_samples = (step + 1) * cfg.training.batch_size
+    num_samples = (step + 1) * cfg.batch_size
     avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / dataset.num_frames
@@ -207,7 +153,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
     # A sample is an (observation, action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
-    num_samples = (step + 1) * cfg.training.batch_size
+    num_samples = (step + 1) * cfg.batch_size
     avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
     num_episodes = num_samples / avg_samples_per_ep
     num_epochs = num_samples / dataset.num_frames
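To make the bookkeeping concrete, the same arithmetic with made-up numbers:

```python
batch_size = 8          # cfg.batch_size
step = 999              # zero-indexed, so (step + 1) updates have run
num_frames = 40_000     # dataset.num_frames
dataset_episodes = 200  # dataset.num_episodes

num_samples = (step + 1) * batch_size               # 8_000 (observation, action) pairs
avg_samples_per_ep = num_frames / dataset_episodes  # 200 frames per episode
num_episodes = num_samples / avg_samples_per_ep     # 40.0 episodes' worth of data
num_epochs = num_samples / num_frames               # 0.2 passes over the dataset
```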
@@ -234,74 +180,17 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
     logger.log_dict(info, step, mode="eval")

-def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
-    if out_dir is None:
-        raise NotImplementedError()
-    if job_name is None:
-        raise NotImplementedError()
+@parser.wrap()
+def train(cfg: TrainPipelineConfig):
+    cfg.validate()

-    init_logging()
-    logging.info(pformat(OmegaConf.to_container(cfg)))
-
-    if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
-        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
-
-    # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
-    # to check for any differences between the provided config and the checkpoint's config.
-    if cfg.resume:
-        if not Logger.get_last_checkpoint_dir(out_dir).exists():
-            raise RuntimeError(
-                "You have set resume=True, but there is no model checkpoint in "
-                f"{Logger.get_last_checkpoint_dir(out_dir)}"
-            )
-        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
-        logging.info(
-            colored(
-                "You have set resume=True, indicating that you wish to resume a run",
-                color="yellow",
-                attrs=["bold"],
-            )
-        )
-        # Get the configuration file from the last checkpoint.
-        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
-        # Check for differences between the checkpoint configuration and provided configuration.
-        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
-        resolve_delta_timestamps(cfg)
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
-        # Ignore the `resume` and parameters.
-        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
-            del diff["values_changed"]["root['resume']"]
-        # Log a warning about differences between the checkpoint configuration and the provided
-        # configuration.
-        if len(diff) > 0:
-            logging.warning(
-                "At least one difference was detected between the checkpoint configuration and "
-                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
-                "takes precedence.",
-            )
-        # Use the checkpoint config instead of the provided config (but keep `resume` parameter).
-        cfg = checkpoint_cfg
-        cfg.resume = True
-    elif Logger.get_last_checkpoint_dir(out_dir).exists():
-        raise RuntimeError(
-            f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
-            "you meant to resume training, please use `resume=true` in your command or yaml configuration."
-        )
-
-    if cfg.eval.batch_size > cfg.eval.n_episodes:
-        raise ValueError(
-            "The eval batch size is greater than the number of eval episodes "
-            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
-            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
-            "This might significantly slow down evaluation. To fix this, you should update your command "
-            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
-            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
-        )
+    logging.info(pformat(asdict(cfg)))

     # log metrics to terminal and wandb
-    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+    logger = Logger(cfg)

-    set_global_seed(cfg.seed)
+    if cfg.seed is not None:
+        set_global_seed(cfg.seed)

     # Check device is available
     device = get_safe_torch_device(cfg.device, log=True)
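The Hydra entry point gives way to a typed dataclass config parsed straight from the CLI (lerobot's `parser` builds on draccus). The toy sketch below is not that implementation; it only shows the shape of the pattern, a nested dataclass with dotted overrides and a `validate()` hook, and every name in it is illustrative:

```python
import sys
from dataclasses import dataclass, field


@dataclass
class EvalConfig:
    n_episodes: int = 10
    batch_size: int = 1


@dataclass
class TrainConfig:
    batch_size: int = 8
    eval: EvalConfig = field(default_factory=EvalConfig)

    def validate(self) -> None:
        if self.eval.batch_size > self.eval.n_episodes:
            raise ValueError("eval.batch_size may not exceed eval.n_episodes")


def parse_overrides(cfg: TrainConfig, argv: list[str]) -> TrainConfig:
    # Each "a.b=c" override walks the nested dataclasses and coerces the
    # value to the field's current type.
    for arg in argv:
        key, value = arg.split("=", 1)
        *path, leaf = key.split(".")
        target = cfg
        for part in path:
            target = getattr(target, part)
        setattr(target, leaf, type(getattr(target, leaf))(value))
    return cfg


cfg = parse_overrides(TrainConfig(), sys.argv[1:])  # e.g. batch_size=32 eval.n_episodes=50
cfg.validate()
```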
@@ -309,65 +198,58 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True

-    logging.info("make_dataset")
+    logging.info("Creating dataset")
     offline_dataset = make_dataset(cfg)
-    if isinstance(offline_dataset, MultiLeRobotDataset):
-        logging.info(
-            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
-        )

     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
     eval_env = None
-    if cfg.training.eval_freq > 0:
-        logging.info("make_env")
-        eval_env = make_env(cfg)
+    if cfg.eval_freq > 0 and cfg.env is not None:
+        logging.info("Creating env")
+        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)

-    logging.info("make_policy")
+    logging.info("Creating policy")
     policy = make_policy(
-        hydra_cfg=cfg,
-        dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+        cfg=cfg.policy,
+        device=device,
+        ds_meta=offline_dataset.meta,
     )
     assert isinstance(policy, nn.Module)

-    # Create optimizer and scheduler
-    # Temporary hack to move optimizer out of policy
+    logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(enabled=cfg.use_amp)
+    grad_scaler = GradScaler(device, enabled=cfg.use_amp)

     step = 0  # number of policy updates (forward + backward + optim)

     if cfg.resume:
-        step = logger.load_last_training_state(optimizer, lr_scheduler)
+        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    log_output_dir(out_dir)
-    logging.info(f"{cfg.env.task=}")
-    logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
-    logging.info(f"{cfg.training.online_steps=}")
+    log_output_dir(cfg.output_dir)
+    if cfg.env is not None:
+        logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
+    logging.info(f"{cfg.online.steps=}")
     logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
     logging.info(f"{offline_dataset.num_episodes=}")
     logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

     # Note: this helper will be used in offline and online training loops.
-    def evaluate_and_checkpoint_if_needed(step, is_online):
-        _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
+    def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
+        _num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
         step_identifier = f"{step:0{_num_digits}d}"

-        if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
+        if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
                 assert eval_env is not None
                 eval_info = eval_policy(
                     eval_env,
                     policy,
                     cfg.eval.n_episodes,
-                    videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
+                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
                     max_episodes_rendered=4,
                     start_seed=cfg.seed,
                 )
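`torch.cuda.amp.GradScaler` is deprecated in favor of `torch.amp.GradScaler`, which takes the target device as its first argument. A self-contained sketch of the AMP step that `update_policy` performs, passing `device.type` where the diff passes `device`, and with made-up hyperparameters:

```python
import torch
from torch import nn
from torch.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler(device.type, enabled=device.type == "cuda")  # no-op on CPU

x, y = torch.randn(8, 16, device=device), torch.randn(8, 1, device=device)
with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # so gradient clipping sees true magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```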
@@ -376,28 +258,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

-        if cfg.training.save_checkpoint and (
-            step % cfg.training.save_freq == 0
-            or step == cfg.training.offline_steps + cfg.training.online_steps
+        if cfg.save_checkpoint and (
+            step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
         ):
             logging.info(f"Checkpoint policy after step {step}")
             # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
             # needed (choose 6 as a minimum for consistency without being overkill).
             logger.save_checkpoint(
                 step,
+                step_identifier,
                 policy,
                 optimizer,
                 lr_scheduler,
-                identifier=step_identifier,
             )
             logging.info("Resume training")

     # create dataloader for offline training
-    if cfg.training.get("drop_n_last_frames"):
+    if getattr(cfg.policy, "drop_n_last_frames", None):
         shuffle = False
         sampler = EpisodeAwareSampler(
             offline_dataset.episode_data_index,
-            drop_n_last_frames=cfg.training.drop_n_last_frames,
+            drop_n_last_frames=cfg.policy.drop_n_last_frames,
             shuffle=True,
         )
     else:
@@ -405,8 +286,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         sampler = None

     dataloader = torch.utils.data.DataLoader(
         offline_dataset,
-        num_workers=cfg.training.num_workers,
-        batch_size=cfg.training.batch_size,
+        num_workers=cfg.num_workers,
+        batch_size=cfg.batch_size,
         shuffle=shuffle,
         sampler=sampler,
         pin_memory=device.type != "cpu",
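`drop_n_last_frames` moves onto the policy config, which makes sense: it is the policy's action horizon that renders the tail of an episode unusable. An illustrative stand-in for what `EpisodeAwareSampler` does (the real class consumes `episode_data_index`; the boundary-tuple constructor below is an assumption):

```python
import torch
from torch.utils.data import Sampler


class DropLastFramesSampler(Sampler[int]):
    """Yields frame indices while skipping the last `drop_n_last_frames` of
    every episode, so a policy that predicts action chunks never samples a
    window that runs off the end of an episode."""

    def __init__(self, episode_bounds: list[tuple[int, int]], drop_n_last_frames: int, shuffle: bool):
        self.indices = [
            i
            for start, end in episode_bounds  # [start, end) frame ranges
            for i in range(start, max(start, end - drop_n_last_frames))
        ]
        self.shuffle = shuffle

    def __iter__(self):
        order = torch.randperm(len(self.indices)) if self.shuffle else torch.arange(len(self.indices))
        return iter(self.indices[i] for i in order.tolist())

    def __len__(self):
        return len(self.indices)


sampler = DropLastFramesSampler([(0, 100), (100, 180)], drop_n_last_frames=7, shuffle=True)
print(len(sampler))  # 166 = (100 - 7) + (80 - 7)
```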
@@ -416,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     policy.train()

     offline_step = 0
-    for _ in range(step, cfg.training.offline_steps):
+    for _ in range(step, cfg.offline.steps):
         if offline_step == 0:
             logging.info("Start offline training on a fixed dataset")
@@ -431,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             policy,
             batch,
             optimizer,
-            cfg.training.grad_clip_norm,
+            cfg.optimizer.grad_clip_norm,
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
@@ -439,7 +320,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         train_info["dataloading_s"] = dataloading_s

-        if step % cfg.training.log_freq == 0:
+        if step % cfg.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)

         # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
@@ -449,7 +330,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         step += 1
         offline_step += 1  # noqa: SIM113

-    if cfg.training.online_steps == 0:
+    if cfg.online.steps == 0:
         if eval_env:
             eval_env.close()
         logging.info("End of training")
@@ -458,8 +339,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Online training.

     # Create an env dedicated to online episodes collection from policy rollout.
-    online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
-    resolve_delta_timestamps(cfg)
+    online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
+    delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)

     online_buffer_path = logger.log_dir / "online_buffer"
     if cfg.resume and not online_buffer_path.exists():
         # If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
@@ -473,31 +354,39 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     online_dataset = OnlineBuffer(
         online_buffer_path,
         data_spec={
-            **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
-            **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
+            **{
+                key: {"shape": ft.shape, "dtype": np.dtype("float32")}
+                for key, ft in policy.config.input_features.items()
+            },
+            **{
+                key: {"shape": ft.shape, "dtype": np.dtype("float32")}
+                for key, ft in policy.config.output_features.items()
+            },
             "next.reward": {"shape": (), "dtype": np.dtype("float32")},
             "next.done": {"shape": (), "dtype": np.dtype("?")},
+            "task_index": {"shape": (), "dtype": np.dtype("int64")},
             # FIXME: 'next.success' is expected by pusht env but not xarm
             "next.success": {"shape": (), "dtype": np.dtype("?")},
         },
-        buffer_capacity=cfg.training.online_buffer_capacity,
+        buffer_capacity=cfg.online.buffer_capacity,
         fps=online_env.unwrapped.metadata["render_fps"],
-        delta_timestamps=cfg.training.delta_timestamps,
+        delta_timestamps=delta_timestamps,
     )

     # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
     # makes it possible to do online rollouts in parallel with training updates).
-    online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
+    online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy

     # Create dataloader for online training.
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
     sampler_weights = compute_sampler_weights(
         offline_dataset,
-        offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+        offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
         online_dataset=online_dataset,
         # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
         # this final observation in the offline datasets, but we might add them in future.
-        online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
-        online_sampling_ratio=cfg.training.online_sampling_ratio,
+        online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
+        online_sampling_ratio=cfg.online.sampling_ratio,
     )
     sampler = torch.utils.data.WeightedRandomSampler(
         sampler_weights,
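The old `input_shapes`/`output_shapes` dicts become typed feature objects exposing a `.shape` attribute. A sketch of how the buffer's `data_spec` is assembled, with plain dicts standing in for the real `PolicyFeature` entries:

```python
import numpy as np

# Hypothetical feature metadata mimicking policy.config.input_features.
input_features = {"observation.state": {"shape": (14,)}, "observation.image": {"shape": (3, 96, 96)}}
output_features = {"action": {"shape": (14,)}}

data_spec = {
    **{k: {"shape": ft["shape"], "dtype": np.dtype("float32")} for k, ft in input_features.items()},
    **{k: {"shape": ft["shape"], "dtype": np.dtype("float32")} for k, ft in output_features.items()},
    "next.reward": {"shape": (), "dtype": np.dtype("float32")},
    "next.done": {"shape": (), "dtype": np.dtype("?")},  # "?" is NumPy's boolean dtype code
    "task_index": {"shape": (), "dtype": np.dtype("int64")},
}
print(sorted(data_spec))
```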
@@ -506,20 +395,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     )

     dataloader = torch.utils.data.DataLoader(
         concat_dataset,
-        batch_size=cfg.training.batch_size,
-        num_workers=cfg.training.num_workers,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.num_workers,
         sampler=sampler,
         pin_memory=device.type != "cpu",
         drop_last=True,
     )
     dl_iter = cycle(dataloader)

-    # Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
-    # these are still used but effectively do nothing.
-    lock = Lock()
-    # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
-    # parallelization of rollouts is handled within the job.
-    executor = ThreadPoolExecutor(max_workers=1)
+    if cfg.online.do_rollout_async:
+        # Lock and thread pool executor for asynchronous online rollouts.
+        lock = Lock()
+        # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
+        # parallelization of rollouts is handled within the job.
+        executor = ThreadPoolExecutor(max_workers=1)
+    else:
+        lock = None

     online_step = 0
     online_rollout_s = 0  # time taken to do online rollout
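Rather than always building a `Lock` and `ThreadPoolExecutor` that "effectively do nothing" in synchronous mode, the executor now exists only in async mode and `lock` is `None` otherwise; every critical section then enters `lock if lock is not None else nullcontext()`. A runnable sketch of the pattern:

```python
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from threading import Lock

do_rollout_async = True  # stand-in for cfg.online.do_rollout_async

if do_rollout_async:
    lock = Lock()
    executor = ThreadPoolExecutor(max_workers=1)  # one rollout job at a time
else:
    lock = None


def rollout() -> float:
    with lock if lock is not None else nullcontext():
        time.sleep(0.01)  # pretend to copy policy weights and step the env
    return 0.01


if lock is None:
    rollout_s = rollout()  # synchronous path: no thread, no locking overhead
else:
    future = executor.submit(rollout)
    # ... training updates could run here while the rollout thread works ...
    rollout_s = future.result()
print(f"rollout took ~{rollout_s}s")
```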
@@ -527,10 +418,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
     # online rollout option.
     await_update_online_buffer_s = 0

-    rollout_start_seed = cfg.training.online_env_seed
+    rollout_start_seed = cfg.online.env_seed

     while True:
-        if online_step == cfg.training.online_steps:
+        if online_step == cfg.online.steps:
             break

         if online_step == 0:
@@ -538,25 +429,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         def sample_trajectory_and_update_buffer():
             nonlocal rollout_start_seed
-            with lock:
+            with lock if lock is not None else nullcontext():
                 online_rollout_policy.load_state_dict(policy.state_dict())
             online_rollout_policy.eval()
             start_rollout_time = time.perf_counter()
             with torch.no_grad():
                 eval_info = eval_policy(
                     online_env,
                     online_rollout_policy,
-                    n_episodes=cfg.training.online_rollout_n_episodes,
-                    max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
+                    n_episodes=cfg.online.rollout_n_episodes,
+                    max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
                     videos_dir=logger.log_dir / "online_rollout_videos",
                     return_episode_data=True,
-                    start_seed=(
-                        rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
-                    ),
+                    start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
                 )
             online_rollout_s = time.perf_counter() - start_rollout_time

-            with lock:
+            if len(offline_dataset.meta.tasks) > 1:
+                raise NotImplementedError("Add support for multi task.")
+
+            # Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
+            total_num_frames = eval_info["episodes"]["index"].shape[0]
+            eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
+
+            with lock if lock is not None else nullcontext():
                 start_update_buffer_time = time.perf_counter()
                 online_dataset.add_data(eval_info["episodes"])
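The multi-line `start_seed` expression collapses to one line: the walrus operator advances `rollout_start_seed` and hands the new value to `eval_policy` in a single expression, wrapping at one million. With made-up numbers:

```python
batch_size = 8                 # cfg.batch_size
rollout_start_seed = 999_996   # close to the wrap-around point

start_seed = (rollout_start_seed := (rollout_start_seed + batch_size) % 1_000_000)
print(rollout_start_seed, start_seed)  # 4 4
```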
@@ -566,12 +465,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
                 # Update the sampling weights.
                 sampler.weights = compute_sampler_weights(
                     offline_dataset,
-                    offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
+                    offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
                     online_dataset=online_dataset,
                     # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
                     # this final observation in the offline datasets, but we might add them in future.
-                    online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
-                    online_sampling_ratio=cfg.training.online_sampling_ratio,
+                    online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
+                    online_sampling_ratio=cfg.online.sampling_ratio,
                 )
                 sampler.num_frames = len(concat_dataset)
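`compute_sampler_weights` keeps the offline/online draw probabilities at a fixed ratio as the online buffer grows. A toy version of the weighting with made-up sizes (the real helper also handles the `drop_n_last_frames` masking):

```python
import torch
from torch.utils.data import WeightedRandomSampler

n_offline, n_online, online_sampling_ratio = 40_000, 2_000, 0.5

# Normalize per side so online frames are drawn with probability
# `online_sampling_ratio` overall, regardless of the two dataset sizes.
weights = torch.cat([
    torch.full((n_offline,), (1 - online_sampling_ratio) / n_offline),
    torch.full((n_online,), online_sampling_ratio / n_online),
])
sampler = WeightedRandomSampler(weights, num_samples=n_offline + n_online)
# After each rollout adds frames, recompute and assign `sampler.weights`
# in place so subsequent draws see the grown online buffer.
```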
@@ -579,36 +478,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             return online_rollout_s, update_online_buffer_s

-        future = executor.submit(sample_trajectory_and_update_buffer)
-        # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
-        # here until the rollout and buffer update is done, before proceeding to the policy update steps.
-        if (
-            not cfg.training.do_online_rollout_async
-            or len(online_dataset) <= cfg.training.online_buffer_seed_size
-        ):
-            online_rollout_s, update_online_buffer_s = future.result()
+        if lock is None:
+            online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
+        else:
+            future = executor.submit(sample_trajectory_and_update_buffer)
+            # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
+            # here until the rollout and buffer update is done, before proceeding to the policy update steps.
+            if len(online_dataset) <= cfg.online.buffer_seed_size:
+                online_rollout_s, update_online_buffer_s = future.result()

-        if len(online_dataset) <= cfg.training.online_buffer_seed_size:
-            logging.info(
-                f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
-            )
+        if len(online_dataset) <= cfg.online.buffer_seed_size:
+            logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
             continue

         policy.train()
-        for _ in range(cfg.training.online_steps_between_rollouts):
-            with lock:
+        for _ in range(cfg.online.steps_between_rollouts):
+            with lock if lock is not None else nullcontext():
                 start_time = time.perf_counter()
                 batch = next(dl_iter)
                 dataloading_s = time.perf_counter() - start_time

             for key in batch:
-                batch[key] = batch[key].to(cfg.device, non_blocking=True)
+                batch[key] = batch[key].to(device, non_blocking=True)

             train_info = update_policy(
                 policy,
                 batch,
                 optimizer,
-                cfg.training.grad_clip_norm,
+                cfg.optimizer.grad_clip_norm,
                 grad_scaler=grad_scaler,
                 lr_scheduler=lr_scheduler,
                 use_amp=cfg.use_amp,
@@ -619,10 +516,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             train_info["online_rollout_s"] = online_rollout_s
             train_info["update_online_buffer_s"] = update_online_buffer_s
             train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
-            with lock:
+            with lock if lock is not None else nullcontext():
                 train_info["online_buffer_size"] = len(online_dataset)

-            if step % cfg.training.log_freq == 0:
+            if step % cfg.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)

             # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
@@ -634,12 +531,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         # If we're doing async rollouts, we should now wait until we've completed them before proceeding
         # to do the next batch of rollouts.
-        if future.running():
+        if cfg.online.do_rollout_async and future.running():
             start = time.perf_counter()
             online_rollout_s, update_online_buffer_s = future.result()
             await_update_online_buffer_s = time.perf_counter() - start

-        if online_step >= cfg.training.online_steps:
+        if online_step >= cfg.online.steps:
             break

     if eval_env:
@@ -647,23 +544,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__":
train_cli()
init_logging()
train()
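Logging setup now happens in the `__main__` guard instead of inside `train()`. An illustrative minimal `init_logging` (the real helper in `lerobot.common.utils.utils` configures its own format and handlers):

```python
import logging


def init_logging() -> None:
    # Minimal stand-in: the real helper also installs a custom formatter.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


if __name__ == "__main__":
    init_logging()
    logging.info("logging is configured before train() runs")
```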