Fix missing policy.to(device) in policy factory (#126)
This commit is contained in:
@@ -65,8 +65,9 @@ def make_policy(
|
||||
if pretrained_policy_name_or_path is None:
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
|
||||
policy = policy_cls(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path)
|
||||
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
|
||||
return policy
|
||||
|
||||
Reference in New Issue
Block a user