forked from tangger/lerobot
Merge remote-tracking branch 'upstream/main' into fix_pusht_diffusion
This commit is contained in:
@@ -9,6 +9,7 @@ from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
@@ -66,9 +67,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.device = torch.device(cfg_device)
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
self.device = get_safe_torch_device(cfg_device)
|
||||
self.diffusion.to(self.device)
|
||||
|
||||
self.ema_diffusion = None
|
||||
self.ema = None
|
||||
|
||||
Reference in New Issue
Block a user