make unit tests pass

This commit is contained in:
Cadene
2024-04-23 21:39:39 +00:00
parent 42ed7bb670
commit 0660f71556
13 changed files with 79 additions and 38 deletions

View File

@@ -6,7 +6,6 @@ import torch
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
@@ -38,12 +37,14 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs)
for key in dataset.image_keys:
# test image keys are float32 in range [0,1]
for key in obs:
if "image" not in key:
continue
img = obs[key]
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model