Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)
This commit is contained in:
@@ -1,17 +1,26 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from lerobot.common.envs.pusht import PushtEnv
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
assert cfg.env == "simxarm"
|
||||
env = SimxarmEnv(
|
||||
task=cfg.task,
|
||||
frame_skip=cfg.action_repeat,
|
||||
from_pixels=cfg.from_pixels,
|
||||
pixels_only=cfg.pixels_only,
|
||||
image_size=cfg.image_size,
|
||||
)
|
||||
kwargs = {
|
||||
"frame_skip": cfg.action_repeat,
|
||||
"from_pixels": cfg.from_pixels,
|
||||
"pixels_only": cfg.pixels_only,
|
||||
"image_size": cfg.image_size,
|
||||
}
|
||||
|
||||
if cfg.env == "simxarm":
|
||||
kwargs["task"] = cfg.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env == "pusht":
|
||||
clsfunc = PushtEnv
|
||||
else:
|
||||
raise ValueError(cfg.env)
|
||||
|
||||
env = clsfunc(**kwargs)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
|
||||
|
||||
Reference in New Issue
Block a user