"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" TODO(alexander-soare): - Remove reliance on Robomimic for SpatialSoftmax. - Remove reliance on diffusers for DDPMScheduler and LR scheduler. """ import math from collections import deque from typing import Callable import einops import torch import torch.nn.functional as F # noqa: N812 import torchvision from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from huggingface_hub import PyTorchModelHubMixin from robomimic.models.base_nets import SpatialSoftmax from torch import Tensor, nn from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.utils import ( get_device_from_parameters, get_dtype_from_parameters, populate_queues, ) class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). """ name = "diffusion" def __init__( self, config: DiffusionConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__() if config is None: config = DiffusionConfig() self.config = config self.normalize_inputs = Normalize( config.input_shapes, config.input_normalization_modes, dataset_stats ) self.normalize_targets = Normalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None self.diffusion = DiffusionModel(config) def reset(self): """ Clear observation and action queues. Should be called on `env.reset()` """ self._queues = { "observation.image": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. This method handles caching a history of observations and an action trajectory generated by the underlying diffusion model. Here's how it works: - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is copied `n_obs_steps` times to fill the cache). - The diffusion model generates `horizon` steps worth of actions. - `n_action_steps` worth of actions are actually kept for execution, starting from the current step. Schematically this looks like: ---------------------------------------------------------------------------------------------- (legend: o = n_obs_steps, h = horizon, a = n_action_steps) |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... 
            ----------------------------------------------------------------------------------------------
            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
        Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
        "horizon" may not be the best name to describe what the variable actually means, because this period
        is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
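        # Illustration with example values (hypothetical numbers, not necessarily your config):
        # with n_obs_steps=2, horizon=16, n_action_steps=8, at environment step n the observation queue
        # holds steps n-1 and n. When the action queue is empty, the diffusion model generates 16 actions
        # covering steps n-1..n+14, and the 8 actions for steps n..n+7 are queued for execution; each
        # subsequent call simply pops one queued action until the queue runs empty again.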
        assert "observation.image" in batch
        assert "observation.state" in batch

        batch = self.normalize_inputs(batch)

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
            # Stack the n latest observations from the queue.
            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
            actions = self.diffusion.generate_actions(batch)

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]

            self._queues["action"].extend(actions.transpose(0, 1))

        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
        return {"loss": loss}


def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
    """
    Factory for noise scheduler instances of the requested type. All kwargs are passed to the scheduler.
    """
    if name == "DDPM":
        return DDPMScheduler(**kwargs)
    elif name == "DDIM":
        return DDIMScheduler(**kwargs)
    else:
        raise ValueError(f"Unsupported noise scheduler type {name}")


class DiffusionModel(nn.Module):
    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config

        self.rgb_encoder = DiffusionRgbEncoder(config)
        self.unet = DiffusionConditionalUnet1d(
            config,
            global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
            * config.n_obs_steps,
        )

        self.noise_scheduler = _make_noise_scheduler(
            config.noise_scheduler_type,
            num_train_timesteps=config.num_train_timesteps,
            beta_start=config.beta_start,
            beta_end=config.beta_end,
            beta_schedule=config.beta_schedule,
            clip_sample=config.clip_sample,
            clip_sample_range=config.clip_sample_range,
            prediction_type=config.prediction_type,
        )

        if config.num_inference_steps is None:
            self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
        else:
            self.num_inference_steps = config.num_inference_steps

    # ========= inference ============
    def conditional_sample(
        self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
    ) -> Tensor:
        device = get_device_from_parameters(self)
        dtype = get_dtype_from_parameters(self)

        # Sample prior.
        sample = torch.randn(
            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
            dtype=dtype,
            device=device,
            generator=generator,
        )

        self.noise_scheduler.set_timesteps(self.num_inference_steps)

        for t in self.noise_scheduler.timesteps:
            # Predict model output.
            model_output = self.unet(
                sample,
                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                global_cond=global_cond,
            )
            # Compute the previous sample: x_t -> x_t-1.
            sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample

        return sample

    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
        }
        """
        assert set(batch).issuperset({"observation.state", "observation.image"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
        # Separate batch and sequence dims.
        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
        # Concatenate state and image features then flatten to (B, global_cond_dim).
        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

        # Run sampling.
        sample = self.conditional_sample(batch_size, global_cond=global_cond)

        # `horizon` steps worth of actions (from the first observation).
        actions = sample[..., : self.config.output_shapes["action"][0]]

        # Extract `n_action_steps` steps worth of actions (from the current observation).
        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]

        return actions

    def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
            "observation.image": (B, n_obs_steps, C, H, W)
            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }
        """
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
        # Separate batch and sequence dims.
        img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
        # Concatenate state and image features then flatten to (B, global_cond_dim).
        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)

        trajectory = batch["action"]

        # Forward diffusion.
        # Sample noise to add to the trajectory.
        eps = torch.randn(trajectory.shape, device=trajectory.device)
        # Sample a random noising timestep for each item in the batch.
        timesteps = torch.randint(
            low=0,
            high=self.noise_scheduler.config.num_train_timesteps,
            size=(trajectory.shape[0],),
            device=trajectory.device,
        ).long()
        # Add noise to the clean trajectories according to the noise magnitude at each timestep.
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)

        # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
        pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)

        # Compute the loss.
        # The target is either the original trajectory, or the noise.
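        # (With the standard DDPM formulation, `add_noise` computes
        #     noisy_trajectory = sqrt(alpha_bar_t) * trajectory + sqrt(1 - alpha_bar_t) * eps,
        # so predicting `eps` and predicting the clean trajectory are two parametrizations of the same
        # denoising problem.)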
        if self.config.prediction_type == "epsilon":
            target = eps
        elif self.config.prediction_type == "sample":
            target = batch["action"]
        else:
            raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
        if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)

        return loss.mean()


class DiffusionRgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.

    Includes the ability to normalize and crop the image first.
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        # Set up optional preprocessing.
        if config.crop_shape is not None:
            self.do_crop = True
            # Always use center crop for eval.
            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
            if config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

        # Set up backbone.
        backbone_model = getattr(torchvision.models, config.vision_backbone)(
            weights=config.pretrained_backbone_weights
        )
        # Note: This assumes that the layer4 feature map is children()[-3]
        # TODO(alexander-soare): Use a safer alternative.
        self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
        if config.use_group_norm:
            if config.pretrained_backbone_weights:
                raise ValueError(
                    "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                )
            self.backbone = _replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
            )

        # Set up pooling and final layers.
        # Use a dry run to get the feature map shape. The dummy input takes the number of image channels
        # from `config.input_shapes`, and the height and width from `config.crop_shape` if cropping is
        # enabled, otherwise from `config.input_shapes`.
        image_shape = config.input_shapes["observation.image"]
        dummy_h_w = config.crop_shape if config.crop_shape is not None else image_shape[1:]
        dummy_input = torch.zeros(size=(1, image_shape[0], *dummy_h_w))
        with torch.inference_mode():
            dummy_feature_map = self.backbone(dummy_input)
        feature_map_shape = tuple(dummy_feature_map.shape[1:])
        self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
        self.feature_dim = config.spatial_softmax_num_keypoints * 2
        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor with pixel values in [0, 1].
        Returns:
            (B, D) image feature.
        """
        # Preprocess: maybe crop (if it was set up in the __init__).
        if self.do_crop:
            if self.training:  # noqa: SIM108
                x = self.maybe_random_crop(x)
            else:
                # Always use center crop for eval.
                x = self.center_crop(x)
        # Extract backbone feature.
        x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
        # Final linear layer with non-linearity.
        x = self.relu(self.out(x))
        return x


def _replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """
    Args:
        root_module: The module for which the submodules need to be replaced
        predicate: Takes a module as an argument and must return True if that module is to be replaced.
        func: Takes a module as an argument and returns a new module to replace it with.
    Returns:
        The root module with its submodules replaced.
    """
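    # Example (mirroring the call in `DiffusionRgbEncoder.__init__` above): swap every BatchNorm2d in a
    # backbone for GroupNorm:
    #     _replace_submodules(
    #         root_module=backbone,
    #         predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    #         func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
    #     )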
""" if predicate(root_module): return func(root_module) replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: parent_module = root_module.get_submodule(".".join(parents)) if isinstance(parent_module, nn.Sequential): src_module = parent_module[int(k)] else: src_module = getattr(parent_module, k) tgt_module = func(src_module) if isinstance(parent_module, nn.Sequential): parent_module[int(k)] = tgt_module else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) return root_module class DiffusionSinusoidalPosEmb(nn.Module): """1D sinusoidal positional embeddings as in Attention is All You Need.""" def __init__(self, dim: int): super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x.unsqueeze(-1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb class DiffusionConv1dBlock(nn.Module): """Conv1d --> GroupNorm --> Mish""" def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() self.block = nn.Sequential( nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), nn.GroupNorm(n_groups, out_channels), nn.Mish(), ) def forward(self, x): return self.block(x) class DiffusionConditionalUnet1d(nn.Module): """A 1D convolutional UNet with FiLM modulation for conditioning. Note: this removes local conditioning as compared to the original diffusion policy code. """ def __init__(self, config: DiffusionConfig, global_cond_dim: int): super().__init__() self.config = config # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), nn.Mish(), nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), ) # The FiLM conditioning dimension. cond_dim = config.diffusion_step_embed_dim + global_cond_dim # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we # just reverse these. in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list( zip(config.down_dims[:-1], config.down_dims[1:], strict=True) ) # Unet encoder. common_res_block_kwargs = { "cond_dim": cond_dim, "kernel_size": config.kernel_size, "n_groups": config.n_groups, "use_film_scale_modulation": config.use_film_scale_modulation, } self.down_modules = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (len(in_out) - 1) self.down_modules.append( nn.ModuleList( [ DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), # Downsample as long as it is not the last block. nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), ] ) ) # Processing in the middle of the auto-encoder. self.mid_modules = nn.ModuleList( [ DiffusionConditionalResidualBlock1d( config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs ), DiffusionConditionalResidualBlock1d( config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs ), ] ) # Unet decoder. 
        self.up_modules = nn.ModuleList([])
        for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            self.up_modules.append(
                nn.ModuleList(
                    [
                        # dim_in * 2, because it takes the encoder's skip connection as well.
                        DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
                        # Upsample as long as it is not the last block.
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
            nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
        )

    def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
        """
        Args:
            x: (B, T, input_dim) tensor for input to the Unet.
            timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
            global_cond: (B, global_cond_dim)
        Returns:
            (B, T, input_dim) diffusion model prediction.
        """
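        # Shape walkthrough (illustrative, assuming action_dim=2, T=16 and down_dims=(512, 1024, 2048);
        # other configs scale accordingly): (B, 2, 16) -> encoder (B, 512, 16) -> (B, 512, 8) ->
        # (B, 1024, 8) -> (B, 1024, 4) -> (B, 2048, 4) -> mid (B, 2048, 4) -> decoder, concatenating skip
        # connections, (B, 1024, 8) -> (B, 512, 16) -> final conv (B, 2, 16). Note that the first encoder
        # skip feature, (B, 512, 16), is appended but never consumed by the decoder.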
        # For 1D convolutions we'll need feature dimension first.
        x = einops.rearrange(x, "b t d -> b d t")

        timesteps_embed = self.diffusion_step_encoder(timestep)

        # If there is a global conditioning feature, concatenate it to the timestep embedding.
        if global_cond is not None:
            global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)
        else:
            global_feature = timesteps_embed

        # Run encoder, keeping track of skip features to pass to the decoder.
        encoder_skip_features: list[Tensor] = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            encoder_skip_features.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # Run decoder, using the skip features from the encoder.
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, "b d t -> b t d")
        return x


class DiffusionConditionalResidualBlock1d(nn.Module):
    """ResNet style 1D convolutional block with FiLM modulation for conditioning."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        kernel_size: int = 3,
        n_groups: int = 8,
        # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
        # FiLM just modulates bias).
        use_film_scale_modulation: bool = False,
    ):
        super().__init__()

        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)

        # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))

        self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)

        # A final convolution for dimension matching the residual (if needed).
        self.residual_conv = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """
        Args:
            x: (B, in_channels, T)
            cond: (B, cond_dim)
        Returns:
            (B, out_channels, T)
        """
        out = self.conv1(x)

        # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
        cond_embed = self.cond_encoder(cond).unsqueeze(-1)
        if self.use_film_scale_modulation:
            # Treat the embedding as a list of scales and biases.
            scale = cond_embed[:, : self.out_channels]
            bias = cond_embed[:, self.out_channels :]
            out = scale * out + bias
        else:
            # Treat the embedding as biases.
            out = out + cond_embed

        out = self.conv2(out)
        out = out + self.residual_conv(x)
        return out
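

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the library API): runs a training-style forward
    # pass and one rollout-style action selection on random data. It assumes the default `DiffusionConfig`
    # describes a single RGB camera plus low-dimensional state/action (as in the PushT setup), and it uses
    # placeholder identity `dataset_stats`, not real dataset statistics.
    config = DiffusionConfig()
    identity_stats = {}
    for key, shape in {**config.input_shapes, **config.output_shapes}.items():
        # Images get (C, 1, 1) stats so they broadcast over (..., C, H, W); vectors use their own shape.
        stat_shape = (shape[0], 1, 1) if "image" in key else tuple(shape)
        identity_stats[key] = {
            "mean": torch.zeros(stat_shape),
            "std": torch.ones(stat_shape),
            "min": -torch.ones(stat_shape),
            "max": torch.ones(stat_shape),
        }
    policy = DiffusionPolicy(config, dataset_stats=identity_stats)

    # Training-style forward pass (batch size 2).
    batch = {
        "observation.image": torch.rand(2, config.n_obs_steps, *config.input_shapes["observation.image"]),
        "observation.state": torch.rand(2, config.n_obs_steps, *config.input_shapes["observation.state"]),
        "action": torch.rand(2, config.horizon, *config.output_shapes["action"]),
        "action_is_pad": torch.zeros(2, config.horizon, dtype=torch.bool),
    }
    print("loss:", policy.forward(batch)["loss"].item())

    # Rollout-style action selection for a single environment.
    policy.reset()
    observation = {
        "observation.image": torch.rand(1, *config.input_shapes["observation.image"]),
        "observation.state": torch.rand(1, *config.input_shapes["observation.state"]),
    }
    print("action shape:", tuple(policy.select_action(observation).shape))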