Tidy up yaml configs (#121)

This commit is contained in:
Alexander Soare
2024-04-30 16:08:59 +01:00
committed by GitHub
parent e4e739f4f8
commit 9d60dce6f3
21 changed files with 142 additions and 207 deletions

View File

@@ -23,8 +23,8 @@ weights_path = folder / "model.pt"
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"eval.n_episodes=10",
"eval.batch_size=10",
"device=cuda",
]

View File

@@ -38,15 +38,13 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
policy.train()
policy.to(device)
optimizer = torch.optim.Adam(
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.batch_size,
batch_size=64,
shuffle=True,
pin_memory=device != torch.device("cpu"),
drop_last=True,