From cc2f6e74047bd65db0f9705fa602636b625bc28c Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 9 Jul 2024 12:35:50 +0100 Subject: [PATCH] Train diffusion pusht_keypoints (#307) Co-authored-by: Remi --- lerobot/common/envs/utils.py | 38 +++--- .../diffusion/configuration_diffusion.py | 44 ++++--- .../policies/diffusion/modeling_diffusion.py | 70 +++++++---- .../policy/diffusion_pusht_keypoints.yaml | 110 ++++++++++++++++++ 4 files changed, 206 insertions(+), 56 deletions(-) create mode 100644 lerobot/configs/policy/diffusion_pusht_keypoints.yaml diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 8fce036..32da006 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -28,31 +28,35 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten """ # map to expected inputs for the policy return_observations = {} + if "pixels" in observations: + if isinstance(observations["pixels"], dict): + imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} + else: + imgs = {"observation.image": observations["pixels"]} - if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} - else: - imgs = {"observation.image": observations["pixels"]} + for imgkey, img in imgs.items(): + img = torch.from_numpy(img) - for imgkey, img in imgs.items(): - img = torch.from_numpy(img) + # sanity check that images are channel last + _, h, w, c = img.shape + assert c < h and c < w, f"expect channel first images, but instead {img.shape}" - # sanity check that images are channel last - _, h, w, c = img.shape - assert c < h and c < w, f"expect channel first images, but instead {img.shape}" + # sanity check that images are uint8 + assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" - # sanity check that images are uint8 - assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + # convert to channel first of type float32 in range [0,1] + img = einops.rearrange(img, "b h w c -> b c h w").contiguous() + img = img.type(torch.float32) + img /= 255 - # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w").contiguous() - img = img.type(torch.float32) - img /= 255 + return_observations[imgkey] = img - return_observations[imgkey] = img + if "environment_state" in observations: + return_observations["observation.environment_state"] = torch.from_numpy( + observations["environment_state"] + ).float() # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing # requirement for "agent_pos" return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() - return return_observations diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 2b7923a..1e1f9d2 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -28,7 +28,10 @@ class DiffusionConfig: Notes on the inputs and outputs: - "observation.state" is required as an input key. - - At least one key starting with "observation.image is required as an input. + - Either: + - At least one key starting with "observation.image is required as an input. + AND/OR + - The key "observation.environment_state" is required as input. 
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera views. Right now we only support all images having the same shape. - "action" is required as an output key. @@ -155,26 +158,33 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} - if self.crop_shape is not None: + + if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes: + raise ValueError("You must provide at least one image or the environment state among the inputs.") + + if len(image_keys) > 0: + if self.crop_shape is not None: + for image_key in image_keys: + if ( + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] + ): + raise ValueError( + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." + ) + # Check that all input images have the same shape. + first_image_key = next(iter(image_keys)) for image_key in image_keys: - if ( - self.crop_shape[0] > self.input_shapes[image_key][1] - or self.crop_shape[1] > self.input_shapes[image_key][2] - ): + if self.input_shapes[image_key] != self.input_shapes[first_image_key]: raise ValueError( - f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " - f"for `crop_shape` and {self.input_shapes[image_key]} for " - "`input_shapes[{image_key}]`." + f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " + "expect all image shapes to match." ) - # Check that all input images have the same shape. - first_image_key = next(iter(image_keys)) - for image_key in image_keys: - if self.input_shapes[image_key] != self.input_shapes[first_image_key]: - raise ValueError( - f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " - "expect all image shapes to match." - ) + supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: raise ValueError( diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 3356539..ec4039c 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -83,16 +83,20 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): self.diffusion = DiffusionModel(config) self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.use_env_state = "observation.environment_state" in config.input_shapes self.reset() def reset(self): """Clear observation and action queues. 
Should be called on `env.reset()`""" self._queues = { - "observation.images": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } + if len(self.expected_image_keys) > 0: + self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + if self.use_env_state: + self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -117,7 +121,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + if len(self.expected_image_keys) > 0: + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -137,7 +142,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + if len(self.expected_image_keys) > 0: + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -161,15 +167,20 @@ class DiffusionModel(nn.Module): super().__init__() self.config = config - self.rgb_encoder = DiffusionRgbEncoder(config) + # Build observation encoders (depending on which observations are provided). + global_cond_dim = config.input_shapes["observation.state"][0] num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) - self.unet = DiffusionConditionalUnet1d( - config, - global_cond_dim=( - config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images - ) - * config.n_obs_steps, - ) + self._use_images = False + self._use_env_state = False + if num_images > 0: + self._use_images = True + self.rgb_encoder = DiffusionRgbEncoder(config) + global_cond_dim += self.rgb_encoder.feature_dim * num_images + if "observation.environment_state" in config.input_shapes: + self._use_env_state = True + global_cond_dim += config.input_shapes["observation.environment_state"][0] + + self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, @@ -219,24 +230,34 @@ class DiffusionModel(nn.Module): def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: """Encode image features and concatenate them all together along with the state vector.""" batch_size, n_obs_steps = batch["observation.state"].shape[:2] + global_cond_feats = [batch["observation.state"]] # Extract image feature (first combine batch, sequence, and camera index dims). - img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") - ) - # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature - # dim (effectively concatenating the camera features). 
- img_features = einops.rearrange( - img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps - ) - # Concatenate state and image features then flatten to (B, global_cond_dim). - return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + if self._use_images: + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the + # feature dim (effectively concatenating the camera features). + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + ) + global_cond_feats.append(img_features) + + if self._use_env_state: + global_cond_feats.append(batch["observation.environment_state"]) + + # Concatenate features then flatten to (B, global_cond_dim). + return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) + AND/OR + "observation.environment_state": (B, environment_dim) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -260,13 +281,18 @@ class DiffusionModel(nn.Module): This function expects `batch` to have (at least): { "observation.state": (B, n_obs_steps, state_dim) + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) + AND/OR + "observation.environment_state": (B, environment_dim) + "action": (B, horizon, action_dim) "action_is_pad": (B, horizon) } """ # Input validation. - assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"}) + assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) + assert "observation.images" in batch or "observation.environment_state" in batch n_obs_steps = batch["observation.state"].shape[1] horizon = batch["action"].shape[1] assert horizon == self.config.horizon diff --git a/lerobot/configs/policy/diffusion_pusht_keypoints.yaml b/lerobot/configs/policy/diffusion_pusht_keypoints.yaml new file mode 100644 index 0000000..a5fe6cf --- /dev/null +++ b/lerobot/configs/policy/diffusion_pusht_keypoints.yaml @@ -0,0 +1,110 @@ +# @package _global_ + +# Defaults for training for the pusht_keypoints dataset. + +# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT +# environment: +# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534 +# For completeness, the diagram is copied here: +# 0───────────1 +# │ │ +# 3───4───5───2 +# │ │ +# │ │ +# │ │ +# │ │ +# 7───6 + + +# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the +# observation along with the agent position and use the encoding as global conditioning for the denoising +# U-Net. + +# Note: We do not track EMA model weights as we discovered it does not improve the results. See +# https://github.com/huggingface/lerobot/pull/134 for more details. 
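+
+# Note: `observation.environment_state` (see `input_shapes` below) is 16-dimensional: the (x, y)
+# coordinates of the 8 keypoints in the diagram above (8 keypoints * 2 coordinates = 16 values).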
+ +seed: 100000 +dataset_repo_id: lerobot/pusht_keypoints + +training: + offline_steps: 200000 + online_steps: 0 + eval_freq: 5000 + save_freq: 5000 + log_freq: 250 + save_checkpoint: true + + batch_size: 64 + grad_clip_norm: 10 + lr: 1.0e-4 + lr_scheduler: cosine + lr_warmup_steps: 500 + adam_betas: [0.95, 0.999] + adam_eps: 1.0e-8 + adam_weight_decay: 1.0e-6 + online_steps_between_rollouts: 1 + + delta_timestamps: + observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" + + # The original implementation doesn't sample frames for the last 7 steps, + # which avoids excessive padding and leads to improved training results. + drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1 + +eval: + n_episodes: 50 + batch_size: 50 + +policy: + name: diffusion + + # Input / output structure. + n_obs_steps: 2 + horizon: 16 + n_action_steps: 8 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.environment_state: [16] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.environment_state: min_max + observation.state: min_max + output_normalization_modes: + action: min_max + + # Architecture / modeling. + # Vision backbone. + vision_backbone: resnet18 + crop_shape: [84, 84] + crop_is_random: True + pretrained_backbone_weights: null + use_group_norm: True + spatial_softmax_num_keypoints: 32 + # Unet. + down_dims: [256, 512, 1024] + kernel_size: 5 + n_groups: 8 + diffusion_step_embed_dim: 128 + use_film_scale_modulation: True + # Noise scheduler. + noise_scheduler_type: DDIM + num_train_timesteps: 100 + beta_schedule: squaredcos_cap_v2 + beta_start: 0.0001 + beta_end: 0.02 + prediction_type: epsilon # epsilon / sample + clip_sample: True + clip_sample_range: 1.0 + + # Inference + num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps` + + # Loss computation + do_mask_loss_for_padding: false
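
As a usage illustration (not part of the patch), here is a minimal sketch of what the patched
`preprocess_observation` returns for a keypoints-only PushT rollout. The batched observation layout
below (a 2-dim "agent_pos" and a 16-dim "environment_state" per vectorized env) is an assumption for
illustration; the key names are the ones handled in lerobot/common/envs/utils.py above.

    import numpy as np
    from lerobot.common.envs.utils import preprocess_observation

    # Fake observation from a single vectorized keypoints-only env: no "pixels" key,
    # just the agent position and the flattened keypoint coordinates.
    observations = {
        "agent_pos": np.zeros((1, 2), dtype=np.float64),           # (num_envs, state_dim)
        "environment_state": np.zeros((1, 16), dtype=np.float64),  # (num_envs, 8 keypoints * 2)
    }

    batch = preprocess_observation(observations)
    print(sorted(batch))
    # ['observation.environment_state', 'observation.state']
    print(batch["observation.environment_state"].shape, batch["observation.state"].dtype)
    # torch.Size([1, 16]) torch.float32

With the config above (n_obs_steps: 2, a 16-dim environment state, and env.state_dim assumed to resolve
to 2 for PushT), the U-Net's global conditioning vector then has (2 + 16) * 2 = 36 features, matching the
`global_cond_dim * config.n_obs_steps` computation introduced in DiffusionModel.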