fix training
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from rl.torchrl.envs.transforms.transforms import Compose, Transform
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
kwargs = {
|
||||
@@ -33,7 +35,13 @@ def make_env(cfg, transform=None):
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
|
||||
Reference in New Issue
Block a user