feat(sim): Add Libero Env (#1984)

This commit is contained in:
Jade Choghari
2025-09-22 15:36:20 +02:00
committed by GitHub
parent f7283193ea
commit 2538472781
13 changed files with 906 additions and 32 deletions

View File

@@ -46,7 +46,10 @@ def test_env(env_name, env_task, obs_type):
@require_env
def test_factory(env_name):
cfg = make_env_config(env_name)
env = make_env(cfg, n_envs=1)
envs = make_env(cfg, n_envs=1)
suite_name = next(iter(envs))
task_id = next(iter(envs[suite_name]))
env = envs[suite_name][task_id]
obs, _ = env.reset()
obs = preprocess_observation(obs)

View File

@@ -159,7 +159,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
assert isinstance(policy, PreTrainedPolicy)
# Check that we run select_actions and get the appropriate output.
env = make_env(train_cfg.env, n_envs=2)
envs = make_env(train_cfg.env, n_envs=2)
dataloader = torch.utils.data.DataLoader(
dataset,
@@ -188,6 +188,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# reset the policy and environment
policy.reset()
# For testing purposes, we only need a single environment instance.
# So here we unwrap the first suite_name and first task_id to grab
# the actual env object (SyncVectorEnv) that exposes `.reset()`.
suite_name = next(iter(envs))
task_id = next(iter(envs[suite_name]))
env = envs[suite_name][task_id]
observation, _ = env.reset(seed=train_cfg.seed)
# apply transform to normalize the observations