fix online training

This commit is contained in:
Cadene
2024-04-16 16:07:39 +00:00
parent 4a3eac4743
commit e09d25267e
2 changed files with 14 additions and 3 deletions

View File

@@ -14,7 +14,7 @@ def preprocess_observation(observation, transform=None):
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img).float()
img = torch.from_numpy(img)
# convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img