Move normalize/unnormalize transforms to policy for act and diffusion

This commit is contained in:
Cadene
2024-04-20 21:08:14 +00:00
parent c1bcf857c5
commit 42ed7bb670
19 changed files with 145 additions and 195 deletions

View File

@@ -42,7 +42,7 @@ def test_factory(env_name):
env = make_env(cfg, num_parallel_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs, transform=dataset.transform)
obs = preprocess_observation(obs)
for key in dataset.image_keys:
img = obs[key]
assert img.dtype == torch.float32