Merge pull request #17 from Cadene/user/rcadene/2024_03_11_bugfix_compute_stats
Fix bugs with normalization
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
@@ -33,7 +33,13 @@ def make_env(cfg, transform=None):
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
|
||||
Reference in New Issue
Block a user