From b47c07fbeb7336f0511c43e83f9a9e9ac6d0f591 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 28 May 2024 21:37:46 +0100 Subject: [PATCH] cherry pick me --- .../common/policies/act/configuration_act.py | 26 ++++--- .../diffusion/configuration_diffusion.py | 60 +++++++++------- .../policies/diffusion/modeling_diffusion.py | 71 ++++++++++--------- .../policies/tdmpc/configuration_tdmpc.py | 9 +++ 4 files changed, 100 insertions(+), 66 deletions(-) diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 80759f66..478b12d4 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -25,6 +25,14 @@ class ACTConfig: The parameters you will most likely need to change are the ones which depend on the environment / sensors. Those are: `input_shapes` and 'output_shapes`. + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - At least one key starting with "observation.image is required as an input. + - If there are multiple keys beginning with "observation.image" they are treated as multiple camera + views. + Right now we only support all images having the same shape. + - "action" is required as an output key. + Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). @@ -33,15 +41,15 @@ class ACTConfig: This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - input_shapes: A dictionary defining the shapes of the input data for the policy. - The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.images.top" refers to an input from the - "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesn't include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 632f6cd6..49dc1bfb 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -26,21 +26,29 @@ class DiffusionConfig: The parameters you will most likely need to change are the ones which depend on the environment / sensors. Those are: `input_shapes` and `output_shapes`. + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - At least one key starting with "observation.image is required as an input. + - If there are multiple keys beginning with "observation.image" they are treated as multiple camera + views. + Right now we only support all images having the same shape. + - "action" is required as an output key. + Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - input_shapes: A dictionary defining the shapes of the input data for the policy. - The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.image" refers to an input from - a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a @@ -148,22 +156,26 @@ class DiffusionConfig: raise ValueError( f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." ) - # There should only be one image key. image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} - if len(image_keys) != 1: - raise ValueError( - f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." - ) - image_key = next(iter(image_keys)) - if ( - self.crop_shape[0] > self.input_shapes[image_key][1] - or self.crop_shape[1] > self.input_shapes[image_key][2] - ): - raise ValueError( - f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " - f"for `crop_shape` and {self.input_shapes[image_key]} for " - "`input_shapes[{image_key}]`." - ) + if self.crop_shape is not None: + for image_key in image_keys: + if ( + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] + ): + raise ValueError( + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." + ) + # Check that all input images have the same shape. + first_image_key = next(iter(image_keys)) + for image_key in image_keys: + if self.input_shapes[image_key] != self.input_shapes[first_image_key]: + raise ValueError( + f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " + "expect all image shapes to match." + ) supported_prediction_types = ["epsilon", "sample"] if self.prediction_type not in supported_prediction_types: raise ValueError( diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 2ae03f22..279a1567 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -18,7 +18,6 @@ TODO(alexander-soare): - Remove reliance on diffusers for DDPMScheduler and LR scheduler. - - Make compatible with multiple image keys. """ import math @@ -83,20 +82,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): self.diffusion = DiffusionModel(config) - image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] - # Note: This check is covered in the post-init of the config but have a sanity check just in case. - if len(image_keys) != 1: - raise NotImplementedError( - f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}." - ) - self.input_image_key = image_keys[0] + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] self.reset() def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.image": deque(maxlen=self.config.n_obs_steps), + "observation.images": deque(maxlen=self.config.n_obs_steps), "observation.state": deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } @@ -124,8 +117,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ batch = self.normalize_inputs(batch) - batch["observation.image"] = batch[self.input_image_key] - + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: @@ -144,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch["observation.image"] = batch[self.input_image_key] + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -169,9 +162,10 @@ class DiffusionModel(nn.Module): self.config = config self.rgb_encoder = DiffusionRgbEncoder(config) + num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) self.unet = DiffusionConditionalUnet1d( config, - global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim) + global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images) * config.n_obs_steps, ) @@ -220,23 +214,34 @@ class DiffusionModel(nn.Module): return sample + def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: + """Encode image features and concatenate them all together along with the state vector.""" + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + # Extract image feature (first combine batch, sequence, and camera index dims). + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature + # dim (effectively concatenating the camera features). + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + ) + # Concatenate state and image features then flatten to (B, global_cond_dim). + return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: """ This function expects `batch` to have: { "observation.state": (B, n_obs_steps, state_dim) - "observation.image": (B, n_obs_steps, C, H, W) + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] assert n_obs_steps == self.config.n_obs_steps - # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) - # Separate batch and sequence dims. - img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) - # Concatenate state and image features then flatten to (B, global_cond_dim). - global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + # Encode image features and concatenate them all together along with the state vector. + global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # run sampling sample = self.conditional_sample(batch_size, global_cond=global_cond) @@ -255,28 +260,23 @@ class DiffusionModel(nn.Module): This function expects `batch` to have (at least): { "observation.state": (B, n_obs_steps, state_dim) - "observation.image": (B, n_obs_steps, C, H, W) + "observation.images": (B, n_obs_steps, num_cameras, C, H, W) "action": (B, horizon, action_dim) "action_is_pad": (B, horizon) } """ # Input validation. - assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"}) + n_obs_steps = batch["observation.state"].shape[1] horizon = batch["action"].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps - # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) - # Separate batch and sequence dims. - img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) - # Concatenate state and image features then flatten to (B, global_cond_dim). - global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) - - trajectory = batch["action"] + # Encode image features and concatenate them all together along with the state vector. + global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # Forward diffusion. + trajectory = batch["action"] # Sample noise to add to the trajectory. eps = torch.randn(trajectory.shape, device=trajectory.device) # Sample a random noising timestep for each item in the batch. @@ -304,7 +304,12 @@ class DiffusionModel(nn.Module): loss = F.mse_loss(pred, target, reduction="none") # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). - if self.config.do_mask_loss_for_padding and "action_is_pad" in batch: + if self.config.do_mask_loss_for_padding: + if "action_is_pad" not in batch: + raise ValueError( + "You need to provide 'action_is_pad' in the batch when " + f"{self.config.do_mask_loss_for_padding=}." + ) in_episode_bound = ~batch["action_is_pad"] loss = loss * in_episode_bound.unsqueeze(-1) @@ -425,7 +430,7 @@ class DiffusionRgbEncoder(nn.Module): # The dummy input should take the number of image channels from `config.input_shapes` and it should # use the height and width from `config.crop_shape`. image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] - assert len(image_keys) == 1 + # Note: we have a check in the config class to make sure all images have the same shape. image_key = image_keys[0] dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index cf76fb08..49485c39 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -31,6 +31,15 @@ class TDMPCConfig: n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google action repeats in Q-learning or ask your favorite chatbot) horizon: Horizon for model predictive control. + input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents + the input data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], + indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't + include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents + the output data name, and the value is a list indicating the dimensions of the corresponding data. + For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. + Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), and the value specifies the normalization mode to apply. The two available modes are "mean_std" which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a