Make policies compatible with other/multiple image keys (#149)
This commit is contained in:
@@ -112,13 +112,12 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
def save(self, fp):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
def load(self, fp):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
self.load_state_dict(torch.load(fp))
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -137,10 +136,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -319,13 +316,11 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
batch_size = batch["index"].shape[0]
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
@@ -353,6 +348,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty_like(reward, device=device)
|
||||
|
||||
Reference in New Issue
Block a user