Fix missing policy.to(device) in policy factory (#126)

This commit is contained in:
Alexander Soare
2024-05-01 17:26:58 +01:00
committed by GitHub
parent d1855a202a
commit c1668924ab

View File

@@ -65,8 +65,9 @@ def make_policy(
if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
policy = policy_cls(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
policy.to(get_safe_torch_device(hydra_cfg.device))
return policy