import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy

class DiffusionPolicy(nn.Module):
    """Thin ``nn.Module`` wrapper around ``DiffusionUnetImagePolicy``.

    Every constructor argument is forwarded unchanged to the wrapped
    policy, which is stored as ``self.diffusion``. Registering the
    wrapped policy as a submodule makes its parameters visible to
    optimizers and ``state_dict`` through this wrapper.
    """

    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        obs_encoder: MultiImageObsEncoder,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
    ):
        super().__init__()
        # Collect the named arguments into one mapping so the hand-off to
        # the wrapped policy is a single, uniform unpacking.
        policy_kwargs = dict(
            shape_meta=shape_meta,
            noise_scheduler=noise_scheduler,
            obs_encoder=obs_encoder,
            horizon=horizon,
            n_action_steps=n_action_steps,
            n_obs_steps=n_obs_steps,
            num_inference_steps=num_inference_steps,
            obs_as_global_cond=obs_as_global_cond,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
        )
        # Extra keyword arguments (parameters passed to step) are handed
        # through untouched.
        self.diffusion = DiffusionUnetImagePolicy(**policy_kwargs, **kwargs)