test_envs are passing

This commit is contained in:
Cadene
2024-04-05 23:27:12 +00:00
parent 5eff40b3d6
commit 44656d2706
7 changed files with 91 additions and 99 deletions

View File

@@ -16,7 +16,6 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_global_seed
class AlohaEnv(gym.Env):
@@ -55,15 +54,20 @@ class AlohaEnv(gym.Env):
elif self.obs_type == "pixels_agent_pos":
self.observation_space = spaces.Dict(
{
"pixels": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
"pixels": spaces.Dict(
{
"top": spaces.Box(
low=0,
high=255,
shape=(self.observation_height, self.observation_width, 3),
dtype=np.uint8,
)
}
),
"agent_pos": spaces.Box(
low=np.array([-1] * len(JOINTS)), # ???
high=np.array([1] * len(JOINTS)), # ???
low=-np.inf,
high=np.inf,
shape=(len(JOINTS),),
dtype=np.float64,
),
}
@@ -89,21 +93,21 @@ class AlohaEnv(gym.Env):
if "transfer_cube" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeTask(random=False)
task = TransferCubeTask()
elif "insertion" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionTask(random=False)
task = InsertionTask()
elif "end_effector_transfer_cube" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeEndEffectorTask(random=False)
task = TransferCubeEndEffectorTask()
elif "end_effector_insertion" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionEndEffectorTask(random=False)
task = InsertionEndEffectorTask()
else:
raise NotImplementedError(task_name)
@@ -116,10 +120,10 @@ class AlohaEnv(gym.Env):
if self.obs_type == "state":
raise NotImplementedError()
elif self.obs_type == "pixels":
obs = raw_obs["images"]["top"].copy()
obs = {"top": raw_obs["images"]["top"].copy()}
elif self.obs_type == "pixels_agent_pos":
obs = {
"pixels": raw_obs["images"]["top"].copy(),
"pixels": {"top": raw_obs["images"]["top"].copy()},
"agent_pos": raw_obs["qpos"],
}
return obs
@@ -129,14 +133,14 @@ class AlohaEnv(gym.Env):
# TODO(rcadene): how to seed the env?
if seed is not None:
set_global_seed(seed)
self._env.task.random.seed(seed)
self._env.task._random = np.random.RandomState(seed)
# TODO(rcadene): do not use global variable for this
if "transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
BOX_POSE[0] = sample_box_pose(seed) # used in sim reset
elif "insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
BOX_POSE[0] = np.concatenate(sample_insertion_pose(seed)) # used in sim reset
else:
raise ValueError(self.task)