test_envs are passing

This commit is contained in:
Cadene
2024-04-05 23:27:12 +00:00
parent 5eff40b3d6
commit 44656d2706
7 changed files with 91 additions and 99 deletions

View File

@@ -6,12 +6,20 @@ from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy
obs = {
"observation.image": torch.from_numpy(observation["pixels"]).float(),
"observation.state": torch.from_numpy(observation["agent_pos"]).float(),
}
# convert to (b c h w) torch format
obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
obs = {}
if isinstance(observation["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img).float()
# convert to (b c h w) torch format
img = einops.rearrange(img, "b h w c -> b c h w")
obs[imgkey] = img
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None: