Add get_safe_torch_device in policies

This commit is contained in:
Simon Alibert
2024-03-20 18:38:55 +01:00
parent ec536ef0fa
commit 4631d36c05
6 changed files with 39 additions and 18 deletions

View File

@@ -8,6 +8,7 @@ from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(AbstractPolicy):
@@ -62,9 +63,8 @@ class DiffusionPolicy(AbstractPolicy):
**kwargs,
)
self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.device = get_safe_torch_device(cfg_device)
self.diffusion.to(self.device)
self.ema = None
if self.cfg.use_ema: