Make sure policies don't mutate the batch (#323)

This commit is contained in:
Alexander Soare
2024-07-22 20:38:33 +01:00
committed by GitHub
parent 0b21210d72
commit abbb1d2367
6 changed files with 27 additions and 5 deletions

View File

@@ -122,6 +122,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -143,6 +144,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)