Make policies compatible with other/multiple image keys (#149)

This commit is contained in:
Alexander Soare
2024-05-16 13:51:53 +01:00
committed by GitHub
parent f52f4f2cd2
commit 68c1b13406
9 changed files with 107 additions and 69 deletions

View File

@@ -147,12 +147,18 @@ class TDMPCConfig:
def __post_init__(self):
"""Input validation (not exhaustive)."""
if self.input_shapes["observation.image"][-2] != self.input_shapes["observation.image"][-1]:
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) != 1:
raise ValueError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
"Only square images are handled now. Got image shape "
f"{self.input_shapes['observation.image']}."
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(