Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)

This commit is contained in:
Cadene
2024-02-20 12:26:57 +00:00
parent fdfb2010fd
commit 3da6ffb2cb
10 changed files with 559 additions and 89 deletions

View File

@@ -1,17 +1,26 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
frame_skip=cfg.action_repeat,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,
)
kwargs = {
"frame_skip": cfg.action_repeat,
"from_pixels": cfg.from_pixels,
"pixels_only": cfg.pixels_only,
"image_size": cfg.image_size,
}
if cfg.env == "simxarm":
kwargs["task"] = cfg.task
clsfunc = SimxarmEnv
elif cfg.env == "pusht":
clsfunc = PushtEnv
else:
raise ValueError(cfg.env)
env = clsfunc(**kwargs)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))