feat(sim): Add Libero Env (#1984)
This commit is contained in:
@@ -46,7 +46,10 @@ def test_env(env_name, env_task, obs_type):
|
||||
@require_env
|
||||
def test_factory(env_name):
|
||||
cfg = make_env_config(env_name)
|
||||
env = make_env(cfg, n_envs=1)
|
||||
envs = make_env(cfg, n_envs=1)
|
||||
suite_name = next(iter(envs))
|
||||
task_id = next(iter(envs[suite_name]))
|
||||
env = envs[suite_name][task_id]
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(train_cfg.env, n_envs=2)
|
||||
envs = make_env(train_cfg.env, n_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
@@ -188,6 +188,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
# For testing purposes, we only need a single environment instance.
|
||||
# So here we unwrap the first suite_name and first task_id to grab
|
||||
# the actual env object (SyncVectorEnv) that exposes `.reset()`.
|
||||
suite_name = next(iter(envs))
|
||||
task_id = next(iter(envs[suite_name]))
|
||||
env = envs[suite_name][task_id]
|
||||
observation, _ = env.reset(seed=train_cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
|
||||
Reference in New Issue
Block a user