Make sure policies don't mutate the batch (#323)

This commit is contained in:
Alexander Soare
2024-07-22 20:38:33 +01:00
committed by GitHub
parent 0b21210d72
commit abbb1d2367
6 changed files with 27 additions and 5 deletions

View File

@@ -137,6 +137,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -316,6 +317,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)