Remove update method from the policy (#99)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Authored by Quentin Gallouédec on 2024-04-29 12:27:58 +02:00, committed by GitHub
parent 5b4fd8891d
commit 508bd92d03
8 changed files with 84 additions and 122 deletions


@@ -38,6 +38,8 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
 policy.train()
 policy.to(device)
+
+optimizer = torch.optim.Adam(policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay)
 
 # Create dataloader for offline training.
 dataloader = torch.utils.data.DataLoader(
     dataset,
@@ -54,9 +56,14 @@ done = False
 while not done:
     for batch in dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        info = policy.update(batch)
+        output_dict = policy.forward(batch)
+        loss = output_dict["loss"]
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
         if step % log_freq == 0:
-            print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
+            print(f"step: {step} loss: {loss.item():.3f}")
         step += 1
         if step >= training_steps:
             done = True
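
For reference, below is a minimal, self-contained sketch of the training-loop pattern introduced by this change: the policy's forward() returns a dict with a "loss" entry, and the calling script owns the optimizer step. The ToyPolicy class, its dimensions, the dummy dataset, and the hyperparameters are illustrative stand-ins, not part of the LeRobot API.

import torch

# Stand-in policy mimicking the post-change interface, where forward()
# returns a dict containing the training loss. Illustrative only; not
# the real DiffusionPolicy.
class ToyPolicy(torch.nn.Module):
    def __init__(self, obs_dim=10, act_dim=2):
        super().__init__()
        self.net = torch.nn.Linear(obs_dim, act_dim)

    def forward(self, batch):
        pred = self.net(batch["observation"])
        loss = torch.nn.functional.mse_loss(pred, batch["action"])
        return {"loss": loss}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = ToyPolicy().to(device)
policy.train()
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

# Dummy tensors standing in for an offline demonstration dataset.
dataset = torch.utils.data.TensorDataset(torch.randn(256, 10), torch.randn(256, 2))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

training_steps = 20
log_freq = 5
step = 0
done = False
while not done:
    for obs, act in dataloader:
        batch = {"observation": obs.to(device), "action": act.to(device)}
        # The caller now runs the optimization step instead of policy.update().
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break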