Add num_workers and batch_size to default.yaml
This commit is contained in:
@@ -26,6 +26,8 @@ training:
|
|||||||
save_freq: ???
|
save_freq: ???
|
||||||
log_freq: 250
|
log_freq: 250
|
||||||
save_model: true
|
save_model: true
|
||||||
|
num_workers: 4
|
||||||
|
batch_size: ???
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 1
|
n_episodes: 1
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=4,
|
num_workers=cfg.training.num_workers,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=cfg.device != "cpu",
|
||||||
|
|||||||
Reference in New Issue
Block a user