Merge remote-tracking branch 'upstream/main' into unify_policy_api

This commit is contained in:
Alexander Soare
2024-04-16 17:30:41 +01:00
4 changed files with 23 additions and 17 deletions

View File

@@ -53,16 +53,10 @@ step = 0
done = False
while not done:
for batch in dataloader:
for k in batch:
batch[k] = batch[k].to(device, non_blocking=True)
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
info = policy.update(batch)
if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size
loss = info["loss"]
update_s = info["update_s"]
print(
f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
)
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
step += 1
if step >= training_steps:
done = True