Make sure policies don't mutate the batch (#323)

This commit is contained in:
Alexander Soare
2024-07-22 20:38:33 +01:00
committed by GitHub
parent 0b21210d72
commit abbb1d2367
6 changed files with 27 additions and 5 deletions

View File

@@ -101,6 +101,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
@@ -128,6 +129,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -467,10 +469,9 @@ class ACT(nn.Module):
if self.use_images:
all_cam_features = []
all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
for cam_index in range(batch["observation.images"].shape[-4]):
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
# buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)

View File

@@ -122,6 +122,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -143,6 +144,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)

View File

@@ -132,6 +132,7 @@ class Normalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
@@ -197,6 +198,7 @@ class Unnormalize(nn.Module):
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_"))

View File

@@ -137,6 +137,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -316,6 +317,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)

View File

@@ -98,6 +98,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -123,6 +124,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)