Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare
2024-04-05 11:41:11 +01:00
21 changed files with 1303 additions and 1370 deletions

View File

@@ -1,64 +1,40 @@
from torchrl.envs import SerialEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
import gymnasium as gym
def make_env(cfg, transform=None):
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
"""
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
environments. The env therefore returns batches.
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv`, which takes a batched action as input and
returns batched observation, reward, terminated, and truncated values, each with `num_parallel_envs` items.
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
"num_prev_obs": cfg.n_obs_steps - 1,
}
kwargs = {}
if cfg.env.name == "simxarm":
from lerobot.common.envs.simxarm.env import SimxarmEnv
kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv
elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht.env import PushtEnv
import gym_pusht # noqa
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
clsfunc = PushtEnv
kwargs.update(
{
"obs_type": "pixels_agent_pos",
"render_action": False,
}
)
env_fn = lambda: gym.make( # noqa: E731
"gym_pusht/PushTPixels-v0",
render_mode="rgb_array",
max_episode_steps=cfg.env.episode_length,
**kwargs,
)
elif cfg.env.name == "aloha":
from lerobot.common.envs.aloha.env import AlohaEnv
kwargs["task"] = cfg.env.task
clsfunc = AlohaEnv
else:
raise ValueError(cfg.env.name)
def _make_env(seed):
nonlocal kwargs
kwargs["seed"] = seed
env = clsfunc(**kwargs)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env
return SerialEnv(
cfg.rollout_batch_size,
create_env_fn=_make_env,
create_env_kwargs=[
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
if num_parallel_envs == 0:
# non-batched version of the env that returns an observation of shape (c)
env = env_fn()
else:
# batched version of the env that returns an observation of shape (b, c)
env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
return env

View File

@@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv):
if not _has_gym:
raise ImportError("Cannot import gymnasium.")
import gymnasium
import gymnasium as gym
from lerobot.common.envs.simxarm.simxarm import TASKS
@@ -65,7 +65,7 @@ class SimxarmEnv(AbstractEnv):
self._env = TASKS[self.task]["env"]()
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0