Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -3,7 +3,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.transforms import Prod
def make_env(cfg):
def make_env(cfg, transform=None):
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
@@ -32,6 +32,10 @@ def make_env(cfg):
# to ensure pusht is in [0,255] like simxarm
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
if transform is not None:
# useful to add mean and std normalization
env.append_transform(transform)
return env