Make policies compatible with other/multiple image keys (#149)

This commit is contained in:
Alexander Soare
2024-05-16 13:51:53 +01:00
committed by GitHub
parent f52f4f2cd2
commit 68c1b13406
9 changed files with 107 additions and 69 deletions

View File

@@ -147,12 +147,18 @@ class TDMPCConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(

View File

@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
config.output_shapes, config.output_normalization_modes, dataset_stats
)
def save(self, fp):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
self.input_image_key = image_keys[0]
def load(self, fp):
"""Load a saved state dict from filepath into current agent."""
self.load_state_dict(torch.load(fp))
self.reset()
def reset(self):
"""
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
"""Select a single action given environment observations."""
assert "observation.image" in batch
assert "observation.state" in batch
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
info = {}
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
batch_size = batch["index"].shape[0]
# (b, t) -> (t, b)
for key in batch:
if batch[key].ndim > 1:
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
# Run latent rollout using the latent dynamics model and policy model.
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)