Make policies compatible with other/multiple image keys (#149)
This commit is contained in:
@@ -147,12 +147,18 @@ class TDMPCConfig:
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
"Only square images are handled now. Got image shape "
|
||||
f"{self.input_shapes['observation.image']}."
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
|
||||
@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
self.load_state_dict(torch.load(fp))
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
|
||||
Reference in New Issue
Block a user