This commit is contained in:
Alexander Soare
2024-04-08 14:44:10 +01:00
parent 0a721f3d94
commit 86365adf9f
4 changed files with 20 additions and 16 deletions

View File

@@ -152,7 +152,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info("make_policy")
policy = make_policy(cfg)
policy.save("act.pt")
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
@@ -198,7 +197,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
is_offline = True
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",