Fix
This commit is contained in:
@@ -4,7 +4,6 @@ import time
|
|||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||||
|
|
||||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||||
@@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cfg,
|
cfg,
|
||||||
|
cfg_device,
|
||||||
cfg_noise_scheduler,
|
cfg_noise_scheduler,
|
||||||
cfg_rgb_model,
|
cfg_rgb_model,
|
||||||
cfg_obs_encoder,
|
cfg_obs_encoder,
|
||||||
@@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device(cfg_device)
|
||||||
self.diffusion.cuda()
|
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||||
|
self.diffusion.cuda()
|
||||||
|
|
||||||
self.ema = None
|
self.ema = None
|
||||||
if self.cfg.use_ema:
|
if self.cfg.use_ema:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ def make_policy(cfg):
|
|||||||
|
|
||||||
policy = DiffusionPolicy(
|
policy = DiffusionPolicy(
|
||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
|
cfg_device=cfg.device,
|
||||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||||
cfg_rgb_model=cfg.rgb_model,
|
cfg_rgb_model=cfg.rgb_model,
|
||||||
cfg_obs_encoder=cfg.obs_encoder,
|
cfg_obs_encoder=cfg.obs_encoder,
|
||||||
|
|||||||
Reference in New Issue
Block a user