forked from tangger/lerobot
Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
@@ -18,92 +18,36 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from threading import Lock
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
has_method,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizer_and_scheduler(cfg, policy):
|
||||
if cfg.policy.name == "act":
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if not n.startswith("model.backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in policy.named_parameters()
|
||||
if n.startswith("model.backbone") and p.requires_grad
|
||||
],
|
||||
"lr": cfg.training.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
|
||||
)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "diffusion":
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(),
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
cfg.training.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=cfg.training.lr_warmup_steps,
|
||||
num_training_steps=cfg.training.offline_steps,
|
||||
)
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
optimizer = VQBeTOptimizer(policy, cfg)
|
||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def update_policy(
|
||||
policy,
|
||||
batch,
|
||||
@@ -145,7 +89,7 @@ def update_policy(
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if isinstance(policy, PolicyWithUpdate):
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
@@ -161,7 +105,9 @@ def update_policy(
|
||||
return info
|
||||
|
||||
|
||||
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
def log_train_info(
|
||||
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
|
||||
):
|
||||
loss = info["loss"]
|
||||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
@@ -170,7 +116,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
@@ -207,7 +153,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
num_samples = (step + 1) * cfg.batch_size
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
@@ -234,74 +180,17 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
logger.log_dict(info, step, mode="eval")
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
|
||||
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
|
||||
# to check for any differences between the provided config and the checkpoint's config.
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Get the configuration file from the last checkpoint.
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
# Log a warning about differences between the checkpoint configuration and the provided
|
||||
# configuration.
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
elif Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
|
||||
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
|
||||
)
|
||||
|
||||
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
||||
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
||||
)
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
logger = Logger(cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
if cfg.seed is not None:
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -309,65 +198,58 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_dataset")
|
||||
logging.info("Creating dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
|
||||
)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env")
|
||||
eval_env = make_env(cfg)
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
logging.info("make_policy")
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
cfg=cfg.policy,
|
||||
device=device,
|
||||
ds_meta=offline_dataset.meta,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, lr_scheduler)
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
log_output_dir(cfg.output_dir)
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
|
||||
logging.info(f"{cfg.online.steps=}")
|
||||
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step, is_online):
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
|
||||
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
assert eval_env is not None
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
@@ -376,28 +258,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_checkpoint and (
|
||||
step % cfg.training.save_freq == 0
|
||||
or step == cfg.training.offline_steps + cfg.training.online_steps
|
||||
if cfg.save_checkpoint and (
|
||||
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_checkpoint(
|
||||
step,
|
||||
step_identifier,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=step_identifier,
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
if cfg.training.get("drop_n_last_frames"):
|
||||
if getattr(cfg.policy, "drop_n_last_frames", None):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
offline_dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.training.drop_n_last_frames,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
@@ -405,8 +286,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
sampler = None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=cfg.training.num_workers,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
@@ -416,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
policy.train()
|
||||
offline_step = 0
|
||||
for _ in range(step, cfg.training.offline_steps):
|
||||
for _ in range(step, cfg.offline.steps):
|
||||
if offline_step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
@@ -431,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
@@ -439,7 +320,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
@@ -449,7 +330,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
step += 1
|
||||
offline_step += 1 # noqa: SIM113
|
||||
|
||||
if cfg.training.online_steps == 0:
|
||||
if cfg.online.steps == 0:
|
||||
if eval_env:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
@@ -458,8 +339,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Online training.
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
resolve_delta_timestamps(cfg)
|
||||
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
|
||||
online_buffer_path = logger.log_dir / "online_buffer"
|
||||
if cfg.resume and not online_buffer_path.exists():
|
||||
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
|
||||
@@ -473,31 +354,39 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_dataset = OnlineBuffer(
|
||||
online_buffer_path,
|
||||
data_spec={
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
|
||||
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.input_features.items()
|
||||
},
|
||||
**{
|
||||
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
|
||||
for key, ft in policy.config.output_features.items()
|
||||
},
|
||||
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
|
||||
"next.done": {"shape": (), "dtype": np.dtype("?")},
|
||||
"task_index": {"shape": (), "dtype": np.dtype("int64")},
|
||||
# FIXME: 'next.success' is expected by pusht env but not xarm
|
||||
"next.success": {"shape": (), "dtype": np.dtype("?")},
|
||||
},
|
||||
buffer_capacity=cfg.training.online_buffer_capacity,
|
||||
buffer_capacity=cfg.online.buffer_capacity,
|
||||
fps=online_env.unwrapped.metadata["render_fps"],
|
||||
delta_timestamps=cfg.training.delta_timestamps,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
|
||||
# makes it possible to do online rollouts in parallel with training updates).
|
||||
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
|
||||
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
|
||||
|
||||
# Create dataloader for online training.
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
sampler_weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
sampler_weights,
|
||||
@@ -506,20 +395,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
num_workers=cfg.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
|
||||
# these are still used but effectively do nothing.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
if cfg.online.do_rollout_async:
|
||||
# Lock and thread pool executor for asynchronous online rollouts.
|
||||
lock = Lock()
|
||||
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
|
||||
# parallelization of rollouts is handled within the job.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
else:
|
||||
lock = None
|
||||
|
||||
online_step = 0
|
||||
online_rollout_s = 0 # time take to do online rollout
|
||||
@@ -527,10 +418,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
|
||||
# online rollout option.
|
||||
await_update_online_buffer_s = 0
|
||||
rollout_start_seed = cfg.training.online_env_seed
|
||||
rollout_start_seed = cfg.online.env_seed
|
||||
|
||||
while True:
|
||||
if online_step == cfg.training.online_steps:
|
||||
if online_step == cfg.online.steps:
|
||||
break
|
||||
|
||||
if online_step == 0:
|
||||
@@ -538,25 +429,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
def sample_trajectory_and_update_buffer():
|
||||
nonlocal rollout_start_seed
|
||||
with lock:
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
online_rollout_policy.load_state_dict(policy.state_dict())
|
||||
|
||||
online_rollout_policy.eval()
|
||||
start_rollout_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
online_env,
|
||||
online_rollout_policy,
|
||||
n_episodes=cfg.training.online_rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
|
||||
n_episodes=cfg.online.rollout_n_episodes,
|
||||
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
|
||||
videos_dir=logger.log_dir / "online_rollout_videos",
|
||||
return_episode_data=True,
|
||||
start_seed=(
|
||||
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
|
||||
),
|
||||
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
|
||||
)
|
||||
online_rollout_s = time.perf_counter() - start_rollout_time
|
||||
|
||||
with lock:
|
||||
if len(offline_dataset.meta.tasks) > 1:
|
||||
raise NotImplementedError("Add support for multi task.")
|
||||
|
||||
# Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
|
||||
total_num_frames = eval_info["episodes"]["index"].shape[0]
|
||||
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_update_buffer_time = time.perf_counter()
|
||||
online_dataset.add_data(eval_info["episodes"])
|
||||
|
||||
@@ -566,12 +465,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# Update the sampling weights.
|
||||
sampler.weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
|
||||
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
|
||||
online_dataset=online_dataset,
|
||||
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
|
||||
# this final observation in the offline datasets, but we might add them in future.
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.online.sampling_ratio,
|
||||
)
|
||||
sampler.num_frames = len(concat_dataset)
|
||||
|
||||
@@ -579,36 +478,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
return online_rollout_s, update_online_buffer_s
|
||||
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if (
|
||||
not cfg.training.do_online_rollout_async
|
||||
or len(online_dataset) <= cfg.training.online_buffer_seed_size
|
||||
):
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
if lock is None:
|
||||
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
|
||||
else:
|
||||
future = executor.submit(sample_trajectory_and_update_buffer)
|
||||
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
|
||||
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
|
||||
if len(online_dataset) <= cfg.training.online_buffer_seed_size:
|
||||
logging.info(
|
||||
f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
|
||||
)
|
||||
if len(online_dataset) <= cfg.online.buffer_seed_size:
|
||||
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
|
||||
continue
|
||||
|
||||
policy.train()
|
||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
||||
with lock:
|
||||
for _ in range(cfg.online.steps_between_rollouts):
|
||||
with lock if lock is not None else nullcontext():
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
||||
train_info = update_policy(
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.training.grad_clip_norm,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.use_amp,
|
||||
@@ -619,10 +516,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
train_info["online_rollout_s"] = online_rollout_s
|
||||
train_info["update_online_buffer_s"] = update_online_buffer_s
|
||||
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
|
||||
with lock:
|
||||
with lock if lock is not None else nullcontext():
|
||||
train_info["online_buffer_size"] = len(online_dataset)
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
if step % cfg.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
@@ -634,12 +531,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
|
||||
# to do the next batch of rollouts.
|
||||
if future.running():
|
||||
if cfg.online.do_rollout_async and future.running():
|
||||
start = time.perf_counter()
|
||||
online_rollout_s, update_online_buffer_s = future.result()
|
||||
await_update_online_buffer_s = time.perf_counter() - start
|
||||
|
||||
if online_step >= cfg.training.online_steps:
|
||||
if online_step >= cfg.online.steps:
|
||||
break
|
||||
|
||||
if eval_env:
|
||||
@@ -647,23 +544,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
init_logging()
|
||||
train()
|
||||
|
||||
Reference in New Issue
Block a user