Files
lerobot_piper/lerobot/common/envs/factory.py
Cadene 1cdfbc8b52 WIP
WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
2024-04-04 15:31:03 +00:00

41 lines
1.4 KiB
Python

import gymnasium as gym
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
    """Build the environment described by ``cfg.env``.

    Args:
        cfg: Config object. This function reads ``cfg.env.name`` and, for the
            ``"pusht"`` env, ``cfg.env.episode_length`` (passed to ``gym.make``
            as ``max_episode_steps``).
        num_parallel_envs: ``0`` returns a single, non-batched env. A positive
            value returns a ``SyncVectorEnv`` wrapping that many copies. It takes
            batched action as input and returns batched observation, reward,
            terminated, truncated of ``num_parallel_envs`` items.

    Returns:
        A ``gym.Env`` when ``num_parallel_envs == 0``, otherwise a
        ``gym.vector.SyncVectorEnv``.

    Raises:
        ValueError: if ``num_parallel_envs`` is negative, or if ``cfg.env.name``
            is not a known env name.
        NotImplementedError: for ``"simxarm"`` and ``"aloha"``, whose env
            construction is not wired up in this factory yet.
    """
    # A negative count would otherwise build an empty SyncVectorEnv, which
    # fails later inside gymnasium with an unhelpful error.
    if num_parallel_envs < 0:
        raise ValueError(f"num_parallel_envs must be >= 0, got {num_parallel_envs}.")

    env_name = cfg.env.name

    if env_name == "pusht":
        # Imported only for its side effect; presumably this registers the
        # "gym_pusht/..." env ids with gymnasium so that gym.make can find them.
        import gym_pusht  # noqa: F401

        # NOTE: seeds 0-200 are used for the demonstration dataset, so the eval
        # env should not be seeded within that range. This is not enforced here.
        env_kwargs = {
            "obs_type": "pixels_agent_pos",
            "render_action": False,
        }

        def env_fn():
            # Called once per env instance (once per copy when vectorized).
            return gym.make(
                "gym_pusht/PushTPixels-v0",
                render_mode="rgb_array",
                max_episode_steps=cfg.env.episode_length,
                **env_kwargs,
            )

    elif env_name in ("simxarm", "aloha"):
        # Previously these branches only filled a kwargs dict and never defined
        # env_fn, so callers hit an UnboundLocalError further down.
        # TODO: define env_fn for these envs (they will need cfg.env.task).
        raise NotImplementedError(f"make_env does not support env '{env_name}' yet (no env constructor is defined for it).")
    else:
        raise ValueError(env_name)

    if num_parallel_envs == 0:
        # Single, non-batched env.
        return env_fn()

    # Batched env: actions in and observation/reward/terminated/truncated out
    # each carry num_parallel_envs items.
    return gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])