revision
This commit is contained in:
@@ -152,7 +152,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
policy.save("act.pt")
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -198,7 +197,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
is_offline = True
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
num_workers=4,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
|
||||
Reference in New Issue
Block a user