Enable logging all the information returned by the forward methods of policies (#151)
@@ -72,6 +72,7 @@ def make_optimizer_and_scheduler(cfg, policy):
 
 
 def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    """Returns a dictionary of items for logging."""
     start_time = time.time()
     policy.train()
     output_dict = policy.forward(batch)
@@ -99,6 +100,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
         "grad_norm": float(grad_norm),
         "lr": optimizer.param_groups[0]["lr"],
         "update_s": time.time() - start_time,
+        **{k: v for k, v in output_dict.items() if k != "loss"},
     }
 
     return info
@@ -122,7 +124,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
     train(cfg, out_dir=out_dir, job_name=job_name)
 
 
-def log_train_info(logger, info, step, cfg, dataset, is_offline):
+def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
     lr = info["lr"]
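
For context, here is a minimal runnable sketch of the pattern this commit enables: a policy whose forward() returns extra metrics alongside "loss", and an update step that spreads those metrics into the logging dict. The ToyPolicy class, its "l1_loss" metric name, and the optimizer settings below are hypothetical illustrations, not part of the lerobot commit.

import time

import torch

# Hypothetical toy policy (not from lerobot): forward() returns a dict
# containing "loss" plus any extra metrics worth logging.
class ToyPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

    def forward(self, batch):
        pred = self.net(batch["observation"])
        loss = torch.nn.functional.mse_loss(pred, batch["action"])
        l1 = torch.nn.functional.l1_loss(pred, batch["action"])
        # Everything except "loss" is passed straight through to the logs.
        return {"loss": loss, "l1_loss": float(l1)}

policy = ToyPolicy()
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
batch = {"observation": torch.randn(8, 4), "action": torch.randn(8, 2)}

# Mirrors the updated update_policy: backpropagate "loss", then forward all
# remaining keys of output_dict into the logging info via the dict spread.
start_time = time.time()
policy.train()
output_dict = policy.forward(batch)
loss = output_dict["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_norm=10.0)
optimizer.step()
optimizer.zero_grad()

info = {
    "loss": loss.item(),
    "grad_norm": float(grad_norm),
    "lr": optimizer.param_groups[0]["lr"],
    "update_s": time.time() - start_time,
    **{k: v for k, v in output_dict.items() if k != "loss"},  # the added line
}
print(info)  # includes "l1_loss" without update_policy knowing its name

With this spread in place, a policy can add or rename diagnostic outputs without any change to the training loop; only the reserved "loss" key has a fixed meaning.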