fix training

This commit is contained in:
Cadene
2024-03-11 12:33:15 +00:00
parent 816b2e9d63
commit ccd5dc5a42
2 changed files with 16 additions and 2 deletions

View File

@@ -1,5 +1,7 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from rl.torchrl.envs.transforms.transforms import Compose, Transform
def make_env(cfg, transform=None):
kwargs = {
@@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
if transform is not None:
# useful to add normalization
env.append_transform(transform)
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env