Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
This commit is contained in:
@@ -3,7 +3,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
from lerobot.common.envs.transforms import Prod
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
def make_env(cfg, transform=None):
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
@@ -32,6 +32,10 @@ def make_env(cfg):
|
||||
# to ensure pusht is in [0,255] like simxarm
|
||||
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add mean and std normalization
|
||||
env.append_transform(transform)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user