import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy


class DiffusionPolicy(nn.Module):
    """Module wrapper that owns a ``DiffusionUnetImagePolicy``.

    The constructor performs no computation of its own: every argument is
    handed on, by keyword and unmodified, to ``DiffusionUnetImagePolicy``,
    and the resulting object is stored on ``self.diffusion``.

    Args:
        shape_meta: Forwarded as-is. Annotated as a ``dict``; its schema is
            not visible here -- TODO confirm against the wrapped policy.
        noise_scheduler: ``DDPMScheduler`` instance, forwarded as-is.
        obs_encoder: ``MultiImageObsEncoder`` instance, forwarded as-is.
        horizon: Forwarded as-is.
        n_action_steps: Forwarded as-is.
        n_obs_steps: Forwarded as-is.
        num_inference_steps: Forwarded as-is (default ``None``).
        obs_as_global_cond: Forwarded as-is (default ``True``).
        diffusion_step_embed_dim: Forwarded as-is (default ``256``).
        down_dims: Forwarded as-is (default ``(256, 512, 1024)``).
        kernel_size: Forwarded as-is (default ``5``).
        n_groups: Forwarded as-is (default ``8``).
        cond_predict_scale: Forwarded as-is (default ``True``).
        **kwargs: Any additional keyword arguments, forwarded as-is.
            NOTE(review): the original inline comment calls these
            "parameters passed to step" -- presumably the scheduler's
            step call inside the wrapped policy; confirm there.
    """

    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        obs_encoder: MultiImageObsEncoder,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
    ):
        super().__init__()

        # Gather the explicitly named settings first so the construction
        # call below stays short. The keys mirror this signature one-to-one,
        # so they cannot clash with anything left over in ``kwargs``.
        policy_settings = {
            "shape_meta": shape_meta,
            "noise_scheduler": noise_scheduler,
            "obs_encoder": obs_encoder,
            "horizon": horizon,
            "n_action_steps": n_action_steps,
            "n_obs_steps": n_obs_steps,
            "num_inference_steps": num_inference_steps,
            "obs_as_global_cond": obs_as_global_cond,
            "diffusion_step_embed_dim": diffusion_step_embed_dim,
            "down_dims": down_dims,
            "kernel_size": kernel_size,
            "n_groups": n_groups,
            "cond_predict_scale": cond_predict_scale,
        }

        # Extra keyword arguments ("parameters passed to step") ride along
        # untouched after the named settings.
        self.diffusion = DiffusionUnetImagePolicy(**policy_settings, **kwargs)