from typing import Dict

import torch
import torch.nn.functional as F  # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from einops import reduce

from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.base_image_policy import BaseImagePolicy


class DiffusionUnetImagePolicy(BaseImagePolicy):
    """Image-conditioned diffusion policy built on a ``ConditionalUnet1D``.

    Observations are encoded per time step by ``obs_encoder``. The model then
    denoises an action trajectory of length ``horizon``. Two conditioning modes
    exist:

    * ``obs_as_global_cond=True``: the encoded features of the first
      ``n_obs_steps`` steps are flattened into one global conditioning vector
      of size ``obs_feature_dim * n_obs_steps``. The trajectory contains only
      actions.
    * ``obs_as_global_cond=False``: the encoded features are concatenated to
      the action channels of the trajectory. At sampling time they are imposed
      by inpainting, using a boolean mask.
    """

    def __init__(
        self,
        shape_meta: dict,
        noise_scheduler: DDPMScheduler,
        obs_encoder: MultiImageObsEncoder,
        horizon,
        n_action_steps,
        n_obs_steps,
        num_inference_steps=None,
        obs_as_global_cond=True,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
        cond_predict_scale=True,
        # parameters passed to step
        **kwargs,
    ):
        """Build the denoising network, mask generator and sampling settings.

        Args:
            shape_meta: Metadata dict. ``shape_meta["action"]["shape"]`` must be
                a 1-element shape giving the action dimension.
            noise_scheduler: Scheduler used for ``add_noise`` during training
                and for ``set_timesteps`` / ``step`` during sampling.
            obs_encoder: Encoder whose ``output_shape()[0]`` is the per-step
                observation feature size.
            horizon: Length (time steps) of the sampled trajectory.
            n_action_steps: Number of actions returned by ``predict_action``.
            n_obs_steps: Number of observation steps used for conditioning.
            num_inference_steps: Number of denoising steps at inference. Defaults
                to ``noise_scheduler.config.num_train_timesteps``.
            obs_as_global_cond: Selects global-feature conditioning (True) or
                inpainting conditioning (False).
            diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
                cond_predict_scale: Forwarded to ``ConditionalUnet1D``.
            **kwargs: Stored as ``self.kwargs``. ``predict_action`` passes them
                to ``conditional_sample``, which currently does not forward them
                to ``scheduler.step``.
        """
        super().__init__()

        # parse shapes
        action_shape = shape_meta["action"]["shape"]
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        # per-step feature size produced by the observation encoder
        obs_feature_dim = obs_encoder.output_shape()[0]

        # create diffusion model
        # Inpainting mode: trajectory channels are [actions | obs features].
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # Global mode: trajectory holds only actions; the obs features of
            # all n_obs_steps are concatenated into the global condition.
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        # In global mode observations are not part of the trajectory, so the
        # mask generator sees obs_dim=0; actions are never visible.
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False,
        )
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    # ========= inference ============
    def conditional_sample(
        self,
        condition_data,
        condition_mask,
        local_cond=None,
        global_cond=None,
        generator=None,
        # keyword arguments to scheduler.step
        **kwargs,
    ):
        """Run reverse diffusion, imposing ``condition_data`` where ``condition_mask`` is True.

        Args:
            condition_data: Tensor whose shape, dtype and device define the
                sampled trajectory. Its values at masked positions are copied
                into the trajectory before every model call and once at the end.
            condition_mask: Boolean tensor with the same shape as
                ``condition_data``.
            local_cond: Passed unchanged to the model.
            global_cond: Passed unchanged to the model.
            generator: Optional ``torch.Generator`` used for the initial noise
                and for ``scheduler.step``.
            **kwargs: Accepted but currently NOT forwarded to ``scheduler.step``
                (see the TODO below).

        Returns:
            The denoised trajectory, with the same shape as ``condition_data``.
        """
        model = self.model
        scheduler = self.noise_scheduler

        # Start from pure Gaussian noise.
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator,
        )

        # set step values (note: this mutates the shared scheduler's state)
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous sample: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output,
                t,
                trajectory,
                generator=generator,
                # **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
            ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Sample an action trajectory conditioned on observations.

        Args:
            obs_dict: Mapping from observation name to tensor. It is used
                directly as the observation dict; there is no nested "obs" key.
                Every value is sliced as ``x[:, :n_obs_steps]``, where
                ``(batch, n_obs_steps)`` are read from the first two dims of the
                first value. ``n_obs_steps`` must equal ``self.n_obs_steps``.
                The dict must not contain "past_action" (not implemented).

        Returns:
            A dict with two entries:

            * "action": the ``n_action_steps`` predicted actions, starting at
              the last observed step (index ``n_obs_steps - 1``).
            * "action_pred": the full predicted action trajectory, of shape
              ``(batch, horizon, action_dim)``.
        """
        assert "past_action" not in obs_dict  # not implemented yet
        nobs = obs_dict
        value = next(iter(nobs.values()))
        bsize, n_obs_steps = value.shape[:2]
        horizon = self.horizon
        action_dim = self.action_dim
        obs_dim = self.obs_feature_dim
        assert self.n_obs_steps == n_obs_steps

        # build input (device/dtype are provided by the BaseImagePolicy base class)
        device = self.device
        dtype = self.dtype

        # Both conditioning modes encode the same observation window. Flatten
        # (B, T, ...) to (B*T, ...) so the encoder processes one step per row.
        this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature: reshape back to (B, T * Do)
            global_cond = nobs_features.reshape(bsize, -1)
            # empty data for action; nothing is inpainted
            cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting: reshape back to (B, T, Do)
            nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
            cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            # Observed features fill the obs channels of the first n_obs_steps steps.
            cond_data[:, :n_obs_steps, action_dim:] = nobs_features
            cond_mask[:, :n_obs_steps, action_dim:] = True

        # run sampling
        nsample = self.conditional_sample(
            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
        )

        # keep only the action channels
        action_pred = nsample[..., :action_dim]

        # get action: the executed window starts at the last observed step
        start = n_obs_steps - 1
        end = start + self.n_action_steps
        action = action_pred[:, start:end]

        result = {"action": action, "action_pred": action_pred}
        return result

    def compute_loss(self, batch):
        """Compute the denoising training loss for one batch.

        Args:
            batch: Dict with two entries. "obs" is a mapping of observation
                tensors shaped ``(B, T, ...)``. "action" is a tensor shaped
                ``(B, horizon, action_dim)``. The batch must not contain
                "valid_mask".

        Returns:
            A scalar tensor. It is the MSE between the model prediction and the
            target, with conditioned (masked) entries zeroed out, averaged over
            all elements.

        Raises:
            ValueError: If the scheduler's ``prediction_type`` is neither
                "epsilon" nor "sample".
        """
        assert "valid_mask" not in batch
        nobs = batch["obs"]
        nactions = batch["action"]
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory
        if self.obs_as_global_cond:
            # reshape B, T, ... to B*T, using only the first n_obs_steps steps
            this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, Do
            global_cond = nobs_features.reshape(batch_size, -1)
        else:
            # reshape B, T, ... to B*T, using all steps
            # (assumes the obs length T equals the action horizon)
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()

        # generate inpainting mask
        condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise in the trajectory's dtype. A default-float32 tensor would
        # promote noisy_trajectory and mismatch the MSE target's dtype when the
        # data is half precision.
        noise = torch.randn(trajectory.shape, device=trajectory.device, dtype=trajectory.dtype)
        bsz = trajectory.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
        ).long()
        # Add noise to the clean trajectories according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)

        # compute loss mask: only non-conditioned entries are supervised
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        # Predict the noise residual (or the clean sample, per prediction_type)
        pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == "epsilon":
            target = noise
        elif pred_type == "sample":
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction="none")
        # Masked entries contribute zero but still count in the mean's denominator.
        loss = loss * loss_mask.type(loss.dtype)
        # This pattern reduces no axis; it only flattens to (B, -1).
        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss.mean()
        return loss