Add multithreading for video generation, Speed policy sampling

This commit is contained in:
Cadene
2024-02-24 18:18:39 +00:00
parent 591985c67d
commit aed02dc7c6
4 changed files with 59 additions and 6 deletions

View File

@@ -46,6 +46,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
raise NotImplementedError()
assert torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
set_seed(cfg.seed)
print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)
@@ -55,9 +56,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO(rcadene): hack for old pretrained models from fowm
if "fowm" in cfg.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)