refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-03-06 17:59:28 +01:00
committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions

View File

@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
use_amp=cfg.policy.use_amp,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,