[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by AdilZouitine
parent 761a2dbcb3
commit 8e6d5f504c
97 changed files with 1596 additions and 492 deletions

View File

@@ -223,14 +223,18 @@ def train(cfg: TrainPipelineConfig):
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler
)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
@@ -273,7 +277,11 @@ def train(cfg: TrainPipelineConfig):
}
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
initial_step=step,
)
logging.info("Start offline training on a fixed dataset")
@@ -327,7 +335,9 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Eval policy at step {step}")
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
):
eval_info = eval_policy(
eval_env,
@@ -344,7 +354,11 @@ def train(cfg: TrainPipelineConfig):
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")