forked from tangger/lerobot
Tidy up yaml configs (#121)
This commit is contained in:
@@ -23,8 +23,8 @@ weights_path = folder / "model.pt"
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"eval.n_episodes=10",
|
||||
"eval.batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
|
||||
@@ -38,15 +38,13 @@ policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, da
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||
)
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.batch_size,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
|
||||
Reference in New Issue
Block a user