Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

This commit is contained in:
Alexander Soare
2024-04-09 15:19:29 +01:00
7 changed files with 82 additions and 68 deletions

View File

@@ -30,10 +30,13 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
if num_parallel_envs == 0:
# non-batched version of the env that returns an observation of shape (c)
env = gym.make(gym_handle, **kwargs)
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
else:
# batched version of the env that returns an observation of shape (b, c)
env = gym.vector.SyncVectorEnv(
[lambda: gym.make(gym_handle, **kwargs) for _ in range(num_parallel_envs)]
[
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
for _ in range(num_parallel_envs)
]
)
return env