Loads episode_data_index and stats during dataset __init__ (#85)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-23 14:13:25 +02:00
committed by GitHub
parent e2168163cd
commit 1030ea0070
89 changed files with 1008 additions and 432 deletions

View File

@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# convert to (b c h w) torch format
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"