Enable logging all the information returned by the forward methods of policies (#151)

This commit is contained in:
Alexander Soare
2024-05-10 07:45:32 +01:00
committed by GitHub
parent b187942db4
commit 1249aee3ac
5 changed files with 12 additions and 4 deletions

View File

@@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"""Returns a dictionary of items for logging."""
start_time = time.time()
policy.train()
output_dict = policy.forward(batch)
@@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.time() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
return info
@@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]