Merge pull request #17 from Cadene/user/rcadene/2024_03_11_bugfix_compute_stats

Fix bugs with normalization
This commit is contained in:
Remi
2024-03-11 13:44:07 +01:00
committed by GitHub
6 changed files with 27 additions and 10 deletions

View File

@@ -1,4 +1,4 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
def make_env(cfg, transform=None):
@@ -33,7 +33,13 @@ def make_env(cfg, transform=None):
if transform is not None:
# useful to add normalization
env.append_transform(transform)
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env