Solve conflicts + pre-commit run -a
This commit is contained in:
@@ -5,7 +5,6 @@ import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
@@ -128,12 +127,8 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": batch["observation", "image"].to(
|
||||
self.device, non_blocking=True
|
||||
),
|
||||
"agent_pos": batch["observation", "state"].to(
|
||||
self.device, non_blocking=True
|
||||
),
|
||||
"image": batch["observation", "image"].to(self.device, non_blocking=True),
|
||||
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": batch["action"].to(self.device, non_blocking=True),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user