Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -6,7 +6,6 @@ import torch
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -38,12 +37,14 @@ def test_factory(env_name):
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs, transform=dataset.transform)
|
||||
for key in dataset.image_keys:
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
# test image keys are float32 in range [0,1]
|
||||
for key in obs:
|
||||
if "image" not in key:
|
||||
continue
|
||||
img = obs[key]
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
||||
Reference in New Issue
Block a user