Remove update method from the policy (#99)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
commit 508bd92d03
parent 5b4fd8891d
@@ -38,6 +38,8 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
 policy.train()
 policy.to(device)
 
+optimizer = torch.optim.Adam(policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay)
+
 # Create dataloader for offline training.
 dataloader = torch.utils.data.DataLoader(
     dataset,
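Since the example script now constructs the optimizer itself, here is a minimal sketch of the same torch.optim.Adam call in keyword form, purely for readability; the hyperparameter values are placeholders standing in for the cfg fields used above (cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay), not values confirmed by this commit:

import torch

# Keyword-argument form of the Adam setup added in the hunk above.
# All values are placeholders; the example reads them from cfg.
optimizer = torch.optim.Adam(
    policy.diffusion.parameters(),
    lr=1e-4,                # cfg.lr
    betas=(0.95, 0.999),    # cfg.adam_betas
    eps=1e-8,               # cfg.adam_eps
    weight_decay=1e-6,      # cfg.adam_weight_decay
)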
@@ -54,9 +56,14 @@ done = False
 while not done:
     for batch in dataloader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
-        info = policy.update(batch)
+        output_dict = policy.forward(batch)
+        loss = output_dict["loss"]
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
 
         if step % log_freq == 0:
-            print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
+            print(f"step: {step} loss: {loss.item():.3f}")
         step += 1
         if step >= training_steps:
             done = True
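Putting the two hunks together, a minimal sketch of the offline training loop as it reads after this change. The policy/dataset construction, the cfg fields, and the device, training_steps, batch_size, and log_freq values are assumed from the surrounding example script and shown only as placeholders; the final break is likewise an assumption, since the diff context ends at done = True:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
training_steps = 5000  # placeholder; defined earlier in the example
log_freq = 250         # placeholder; defined earlier in the example

policy.train()
policy.to(device)

optimizer = torch.optim.Adam(policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay)

# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,  # placeholder
    shuffle=True,
)

step = 0
done = False
while not done:
    for batch in dataloader:
        # Move every tensor in the batch to the training device.
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        # With policy.update() removed, forward() returns a dict that
        # carries the training loss.
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break  # assumption: stop mid-epoch once the step budget is hit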