cherry pick me

Alexander Soare
2024-05-28 21:37:46 +01:00
committed by Remi Cadene
parent 220b32441d
commit b47c07fbeb
4 changed files with 100 additions and 66 deletions

View File

@@ -25,6 +25,14 @@ class ACTConfig:
    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

+   Notes on the inputs and outputs:
+       - "observation.state" is required as an input key.
+       - At least one key starting with "observation.image" is required as an input.
+       - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+         views. Right now we only support all images having the same shape.
+       - "action" is required as an output key.
+
    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
@@ -33,15 +41,15 @@ class ACTConfig:
            This should be no greater than the chunk size. For example, if the chunk size is 100, you may
            set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
            environment, and throws the other 50 out.
-       input_shapes: A dictionary defining the shapes of the input data for the policy.
-           The key represents the input data name, and the value is a list indicating the dimensions
-           of the corresponding data. For example, "observation.images.top" refers to an input from the
-           "top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-           Importantly, shapes doesn't include batch dimension or temporal dimension.
-       output_shapes: A dictionary defining the shapes of the output data for the policy.
-           The key represents the output data name, and the value is a list indicating the dimensions
-           of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-           14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+       input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+           the input data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+           indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+           include batch dimension or temporal dimension.
+       output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+           the output data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+           Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
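For orientation, a config consistent with the documented rules above could look like the following sketch (camera names and dimensions here are hypothetical, not taken from this commit):

    # Two cameras (both keys start with "observation.image") sharing one resolution,
    # plus the required "observation.state" input and "action" output.
    input_shapes = {
        "observation.images.top": [3, 480, 640],
        "observation.images.wrist": [3, 480, 640],  # must match the other camera's shape
        "observation.state": [14],
    }
    output_shapes = {
        "action": [14],
    }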

View File

@@ -26,21 +26,29 @@ class DiffusionConfig:
    The parameters you will most likely need to change are the ones which depend on the environment / sensors.
    Those are: `input_shapes` and `output_shapes`.

+   Notes on the inputs and outputs:
+       - "observation.state" is required as an input key.
+       - At least one key starting with "observation.image" is required as an input.
+       - If there are multiple keys beginning with "observation.image" they are treated as multiple camera
+         views. Right now we only support all images having the same shape.
+       - "action" is required as an output key.
+
    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiffusionPolicy.select_action` for more details.
-       input_shapes: A dictionary defining the shapes of the input data for the policy.
-           The key represents the input data name, and the value is a list indicating the dimensions
-           of the corresponding data. For example, "observation.image" refers to an input from
-           a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
-           Importantly, shapes doesn't include batch dimension or temporal dimension.
-       output_shapes: A dictionary defining the shapes of the output data for the policy.
-           The key represents the output data name, and the value is a list indicating the dimensions
-           of the corresponding data. For example, "action" refers to an output shape of [14], indicating
-           14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
+       input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+           the input data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+           indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+           include batch dimension or temporal dimension.
+       output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+           the output data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+           Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
@@ -148,22 +156,26 @@ class DiffusionConfig:
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )
-       # There should only be one image key.
        image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
-       if len(image_keys) != 1:
-           raise ValueError(
-               f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
-           )
-       image_key = next(iter(image_keys))
-       if (
-           self.crop_shape[0] > self.input_shapes[image_key][1]
-           or self.crop_shape[1] > self.input_shapes[image_key][2]
-       ):
-           raise ValueError(
-               f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
-               f"for `crop_shape` and {self.input_shapes[image_key]} for "
-               "`input_shapes[{image_key}]`."
-           )
+       if self.crop_shape is not None:
+           for image_key in image_keys:
+               if (
+                   self.crop_shape[0] > self.input_shapes[image_key][1]
+                   or self.crop_shape[1] > self.input_shapes[image_key][2]
+               ):
+                   raise ValueError(
+                       f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
+                       f"for `crop_shape` and {self.input_shapes[image_key]} for "
+                       f"`input_shapes[{image_key}]`."
+                   )
+       # Check that all input images have the same shape.
+       first_image_key = next(iter(image_keys))
+       for image_key in image_keys:
+           if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
+               raise ValueError(
+                   f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
+                   "expect all image shapes to match."
+               )
        supported_prediction_types = ["epsilon", "sample"]
        if self.prediction_type not in supported_prediction_types:
            raise ValueError(
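A standalone sketch of what the reworked validation now rejects (hypothetical shapes; in the real class this logic runs inside `__post_init__`):

    # Two cameras with mismatched resolutions now fail at config time.
    input_shapes = {
        "observation.images.top": [3, 96, 96],
        "observation.images.wrist": [3, 84, 84],  # differs from the first camera
    }
    image_keys = {k for k in input_shapes if k.startswith("observation.image")}
    first_image_key = next(iter(image_keys))
    for image_key in image_keys:
        if input_shapes[image_key] != input_shapes[first_image_key]:
            raise ValueError(
                f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
                "expect all image shapes to match."
            )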

View File

@@ -18,7 +18,6 @@
TODO(alexander-soare):
  - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
-  - Make compatible with multiple image keys.
"""

import math
@@ -83,20 +82,14 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
        self.diffusion = DiffusionModel(config)

-       image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-       # Note: This check is covered in the post-init of the config but have a sanity check just in case.
-       if len(image_keys) != 1:
-           raise NotImplementedError(
-               f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
-           )
-       self.input_image_key = image_keys[0]
+       self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        self.reset()

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
-           "observation.image": deque(maxlen=self.config.n_obs_steps),
+           "observation.images": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_steps),
        }
@@ -124,8 +117,8 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
        actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        batch = self.normalize_inputs(batch)
-       batch["observation.image"] = batch[self.input_image_key]
+       batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
+       # Note: It's important that this happens after stacking the images into a single key.
        self._queues = populate_queues(self._queues, batch)

        if len(self._queues["action"]) == 0:
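To make the stacking concrete (shapes are hypothetical): in `select_action` each per-camera entry is (B, C, H, W) at this point, and `torch.stack(..., dim=-4)` inserts a camera axis just before the channel dim; the same negative dim also lands the camera axis correctly for the (B, n_obs_steps, C, H, W) batches seen in `forward`:

    import torch

    batch = {
        "observation.images.top": torch.zeros(8, 3, 96, 96),    # (B, C, H, W), hypothetical keys/sizes
        "observation.images.wrist": torch.zeros(8, 3, 96, 96),
    }
    expected_image_keys = ["observation.images.top", "observation.images.wrist"]
    stacked = torch.stack([batch[k] for k in expected_image_keys], dim=-4)
    print(stacked.shape)  # torch.Size([8, 2, 3, 96, 96]) -> (B, num_cameras, C, H, W)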
@@ -144,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the batch through the model and compute the loss for training or validation."""
        batch = self.normalize_inputs(batch)
-       batch["observation.image"] = batch[self.input_image_key]
+       batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
        batch = self.normalize_targets(batch)
        loss = self.diffusion.compute_loss(batch)
        return {"loss": loss}
@@ -169,9 +162,10 @@ class DiffusionModel(nn.Module):
        self.config = config

        self.rgb_encoder = DiffusionRgbEncoder(config)
+       num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
        self.unet = DiffusionConditionalUnet1d(
            config,
-           global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
+           global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images)
            * config.n_obs_steps,
        )
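The conditioning width now scales with the number of cameras. With hypothetical numbers (14-dim `output_shapes["action"]`, a 512-dim encoder feature, two cameras, `n_obs_steps=2`), the arithmetic works out as:

    action_dim = 14      # config.output_shapes["action"][0] (hypothetical)
    feature_dim = 512    # self.rgb_encoder.feature_dim (hypothetical)
    num_images = 2
    n_obs_steps = 2
    global_cond_dim = (action_dim + feature_dim * num_images) * n_obs_steps
    print(global_cond_dim)  # (14 + 1024) * 2 = 2076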
@@ -220,23 +214,34 @@ class DiffusionModel(nn.Module):

        return sample

+   def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
+       """Encode image features and concatenate them all together along with the state vector."""
+       batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+       # Extract image feature (first combine batch, sequence, and camera index dims).
+       img_features = self.rgb_encoder(
+           einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+       )
+       # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the feature
+       # dim (effectively concatenating the camera features).
+       img_features = einops.rearrange(
+           img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+       )
+       # Concatenate state and image features then flatten to (B, global_cond_dim).
+       return torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+
    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        """
        This function expects `batch` to have:
        {
            "observation.state": (B, n_obs_steps, state_dim)
-           "observation.image": (B, n_obs_steps, C, H, W)
+           "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
        }
        """
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

-       # Extract image feature (first combine batch and sequence dims).
-       img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-       # Separate batch and sequence dims.
-       img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-       # Concatenate state and image features then flatten to (B, global_cond_dim).
-       global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
+       # Encode image features and concatenate them all together along with the state vector.
+       global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # run sampling
        sample = self.conditional_sample(batch_size, global_cond=global_cond)
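A shape-level sketch of the new `_prepare_global_conditioning` path (dimensions are hypothetical and the encoder is stubbed out with zeros, so this only traces shapes):

    import einops
    import torch

    B, S, N, C, H, W = 8, 2, 2, 3, 96, 96   # batch, n_obs_steps, num_cameras, image dims (hypothetical)
    state_dim, feature_dim = 14, 512
    images = torch.zeros(B, S, N, C, H, W)   # "observation.images"
    state = torch.zeros(B, S, state_dim)     # "observation.state"

    # Fold batch, sequence and camera dims together so the encoder sees plain image batches.
    flat = einops.rearrange(images, "b s n ... -> (b s n) ...")              # (B*S*N, C, H, W)
    feats = torch.zeros(flat.shape[0], feature_dim)                          # stand-in for rgb_encoder(flat)
    # Unfold again, absorbing the camera dim into the feature dim.
    feats = einops.rearrange(feats, "(b s n) ... -> b s (n ...)", b=B, s=S)  # (B, S, N * feature_dim)
    global_cond = torch.cat([state, feats], dim=-1).flatten(start_dim=1)
    print(global_cond.shape)  # torch.Size([8, 2076]) -> (B, global_cond_dim)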
@@ -255,28 +260,23 @@ class DiffusionModel(nn.Module):
        This function expects `batch` to have (at least):
        {
            "observation.state": (B, n_obs_steps, state_dim)
-           "observation.image": (B, n_obs_steps, C, H, W)
+           "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
            "action": (B, horizon, action_dim)
            "action_is_pad": (B, horizon)
        }
        """
        # Input validation.
-       assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
-       batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+       assert set(batch).issuperset({"observation.state", "observation.images", "action", "action_is_pad"})
+       n_obs_steps = batch["observation.state"].shape[1]
        horizon = batch["action"].shape[1]
        assert horizon == self.config.horizon
        assert n_obs_steps == self.config.n_obs_steps

-       # Extract image feature (first combine batch and sequence dims).
-       img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
-       # Separate batch and sequence dims.
-       img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-       # Concatenate state and image features then flatten to (B, global_cond_dim).
-       global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
-
-       trajectory = batch["action"]
+       # Encode image features and concatenate them all together along with the state vector.
+       global_cond = self._prepare_global_conditioning(batch)  # (B, global_cond_dim)

        # Forward diffusion.
+       trajectory = batch["action"]
        # Sample noise to add to the trajectory.
        eps = torch.randn(trajectory.shape, device=trajectory.device)
        # Sample a random noising timestep for each item in the batch.
@@ -304,7 +304,12 @@ class DiffusionModel(nn.Module):
        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
-       if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
+       if self.config.do_mask_loss_for_padding:
+           if "action_is_pad" not in batch:
+               raise ValueError(
+                   "You need to provide 'action_is_pad' in the batch when "
+                   f"{self.config.do_mask_loss_for_padding=}."
+               )
            in_episode_bound = ~batch["action_is_pad"]
            loss = loss * in_episode_bound.unsqueeze(-1)
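A toy illustration of the masking (hypothetical sizes): steps flagged as padding are zeroed out of the unreduced loss, so they cannot contribute gradient:

    import torch

    loss = torch.ones(2, 5, 4)  # (B, horizon, action_dim) unreduced MSE, hypothetical values
    action_is_pad = torch.tensor(
        [[False, False, False, True, True],
         [False, False, False, False, True]]
    )  # (B, horizon): True where the action was padded with copies
    in_episode_bound = ~action_is_pad
    loss = loss * in_episode_bound.unsqueeze(-1)
    print(loss.sum(dim=(1, 2)))  # tensor([12., 16.]) -> only in-episode steps remain non-zero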
@@ -425,7 +430,7 @@ class DiffusionRgbEncoder(nn.Module):
        # The dummy input should take the number of image channels from `config.input_shapes` and it should
        # use the height and width from `config.crop_shape`.
        image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-       assert len(image_keys) == 1
+       # Note: we have a check in the config class to make sure all images have the same shape.
        image_key = image_keys[0]
        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
        with torch.inference_mode():
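The encoder's feature dimension is worked out by pushing a dummy input through the backbone; a minimal sketch of that shape-inference trick (hypothetical backbone and crop size, not the class's exact wiring):

    import torch
    import torchvision

    # Everything up to (but not including) the average pool and fc head.
    backbone = torch.nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])
    crop_shape = (84, 84)  # hypothetical stand-in for config.crop_shape
    dummy_input = torch.zeros(1, 3, *crop_shape)
    with torch.inference_mode():
        feature_map_shape = tuple(backbone(dummy_input).shape[1:])
    print(feature_map_shape)  # (512, 3, 3) for this backbone and crop size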

View File

@@ -31,6 +31,15 @@ class TDMPCConfig:
        n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
            action repeats in Q-learning or ask your favorite chatbot)
        horizon: Horizon for model predictive control.
+       input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
+           the input data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
+           indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
+           include batch dimension or temporal dimension.
+       output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
+           the output data name, and the value is a list indicating the dimensions of the corresponding data.
+           For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
+           Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
        input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
            and the value specifies the normalization mode to apply. The two available modes are "mean_std"
            which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a