diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_ee_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml similarity index 100% rename from lerobot/common/envs/aloha/assets/bimanual_viperx_ee_insertion.xml rename to lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_ee_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml similarity index 100% rename from lerobot/common/envs/aloha/assets/bimanual_viperx_ee_transfer_cube.xml rename to lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py index 082d3a6c0..e582e5f30 100644 --- a/lerobot/common/envs/aloha/constants.py +++ b/lerobot/common/envs/aloha/constants.py @@ -26,8 +26,6 @@ JOINTS = [ "right_arm_gripper", ] -# TODO(rcadene): this is for end to end, not when we control end effector -# TODO(rcadene): dimension names are wrong ACTIONS = [ # position and quaternion for end effector "left_arm_waist", @@ -36,19 +34,16 @@ ACTIONS = [ "left_arm_forearm_roll", "left_arm_wrist_angle", "left_arm_wrist_rotate", - "left_arm_left_finger", # normalized gripper position (0: close, 1: open) - "left_arm_right_finger", - # position and quaternion for end effector + "left_arm_gripper", "right_arm_waist", "right_arm_shoulder", "right_arm_elbow", "right_arm_forearm_roll", "right_arm_wrist_angle", "right_arm_wrist_rotate", - "right_arm_left_finger", # normalized gripper position (0: close, 1: open) - "right_arm_right_finger", + "right_arm_gripper", ] diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index d92c2f492..acb30b325 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -1,4 +1,3 @@ -import collections import importlib import logging from collections import deque @@ -9,7 +8,6 @@ import numpy as np import torch from dm_control import mujoco from dm_control.rl import control -from dm_control.suite import base from tensordict import TensorDict from torchrl.data.tensor_specs import ( BoundedTensorSpec, @@ -19,293 +17,24 @@ from torchrl.data.tensor_specs import ( ) from torchrl.envs import EnvBase -from lerobot.common.utils import set_seed - -from .constants import ( +from lerobot.common.envs.aloha.constants import ( ACTIONS, ASSETS_DIR, DT, JOINTS, - PUPPET_GRIPPER_POSITION_CLOSE, - START_ARM_POSE, - normalize_puppet_gripper_position, - normalize_puppet_gripper_velocity, - unnormalize_puppet_gripper_position, ) +from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask +from lerobot.common.envs.aloha.tasks.sim_end_effector import ( + InsertionEndEffectorTask, + TransferCubeEndEffectorTask, +) +from lerobot.common.utils import set_seed + from .utils import sample_box_pose, sample_insertion_pose _has_gym = importlib.util.find_spec("gym") is not None -# def make_ee_sim_env(task_name): -# """ -# Environment for simulated robot bi-manual manipulation, with end-effector control. 
-# Action space: [left_arm_pose (7), # position and quaternion for end effector -# left_gripper_positions (1), # normalized gripper position (0: close, 1: open) -# right_arm_pose (7), # position and quaternion for end effector -# right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) - -# Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position -# left_gripper_position (1), # normalized gripper position (0: close, 1: open) -# right_arm_qpos (6), # absolute joint position -# right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) -# "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) -# left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) -# right_arm_qvel (6), # absolute joint velocity (rad) -# right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) -# "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' -# """ -# if "sim_transfer_cube" in task_name: -# xml_path = ASSETS_DIR / "bimanual_viperx_ee_transfer_cube.xml" -# physics = mujoco.Physics.from_xml_path(xml_path) -# task = TransferCubeEETask(random=False) -# env = control.Environment( -# physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False -# ) -# elif "sim_insertion" in task_name: -# xml_path = ASSETS_DIR / "bimanual_viperx_ee_insertion.xml" -# physics = mujoco.Physics.from_xml_path(xml_path) -# task = InsertionEETask(random=False) -# env = control.Environment( -# physics, task, time_limit=20, control_timestep=DT, n_sub_steps=None, flat_observation=False -# ) -# else: -# raise NotImplementedError -# return env - - -class BimanualViperXEETask(base.Task): - def __init__(self, random=None): - super().__init__(random=random) - - def before_step(self, action, physics): - a_len = len(action) // 2 - action_left = action[:a_len] - action_right = action[a_len:] - - # set mocap position and quat - # left - np.copyto(physics.data.mocap_pos[0], action_left[:3]) - np.copyto(physics.data.mocap_quat[0], action_left[3:7]) - # right - np.copyto(physics.data.mocap_pos[1], action_right[:3]) - np.copyto(physics.data.mocap_quat[1], action_right[3:7]) - - # set gripper - g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7]) - g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7]) - np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) - - def initialize_robots(self, physics): - # reset joint position - physics.named.data.qpos[:16] = START_ARM_POSE - - # reset mocap to align with end effector - # to obtain these numbers: - # (1) make an ee_sim env and reset to the same start_pose - # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] - # get env._physics.named.data.xquat['vx300s_left/gripper_link'] - # repeat the same for right side - np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) - np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0]) - # right - np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) - np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) - - # reset gripper control - close_gripper_control = np.array( - [ - PUPPET_GRIPPER_POSITION_CLOSE, - -PUPPET_GRIPPER_POSITION_CLOSE, - PUPPET_GRIPPER_POSITION_CLOSE, - -PUPPET_GRIPPER_POSITION_CLOSE, - ] - ) - np.copyto(physics.data.ctrl, close_gripper_control) - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - 
super().initialize_episode(physics) - - @staticmethod - def get_qpos(physics): - qpos_raw = physics.data.qpos.copy() - left_qpos_raw = qpos_raw[:8] - right_qpos_raw = qpos_raw[8:16] - left_arm_qpos = left_qpos_raw[:6] - right_arm_qpos = right_qpos_raw[:6] - left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] - right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] - return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) - - @staticmethod - def get_qvel(physics): - qvel_raw = physics.data.qvel.copy() - left_qvel_raw = qvel_raw[:8] - right_qvel_raw = qvel_raw[8:16] - left_arm_qvel = left_qvel_raw[:6] - right_arm_qvel = right_qvel_raw[:6] - left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] - right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] - return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) - - @staticmethod - def get_env_state(physics): - raise NotImplementedError - - def get_observation(self, physics): - # note: it is important to do .copy() - obs = collections.OrderedDict() - obs["qpos"] = self.get_qpos(physics) - obs["qvel"] = self.get_qvel(physics) - obs["env_state"] = self.get_env_state(physics) - obs["images"] = {} - obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") - obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") - obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") - # used in scripted policy to obtain starting pose - obs["mocap_pose_left"] = np.concatenate( - [physics.data.mocap_pos[0], physics.data.mocap_quat[0]] - ).copy() - obs["mocap_pose_right"] = np.concatenate( - [physics.data.mocap_pos[1], physics.data.mocap_quat[1]] - ).copy() - - # used when replaying joint trajectory - obs["gripper_ctrl"] = physics.data.ctrl.copy() - return obs - - def get_reward(self, physics): - raise NotImplementedError - - -class TransferCubeEETask(BimanualViperXEETask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - self.initialize_robots(physics) - # randomize box position - cube_pose = sample_box_pose() - box_start_idx = physics.model.name2id("red_box_joint", "joint") - np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) - # print(f"randomized cube position to {cube_position}") - - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether left gripper is holding the box - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_table = ("red_box", "table") in all_contact_pairs - - reward = 0 - if touch_right_gripper: - reward = 1 - if touch_right_gripper and not touch_table: # lifted - reward = 2 - if 
touch_left_gripper: # attempted transfer - reward = 3 - if touch_left_gripper and not touch_table: # successful transfer - reward = 4 - return reward - - -class InsertionEETask(BimanualViperXEETask): - def __init__(self, random=None): - super().__init__(random=random) - self.max_reward = 4 - - def initialize_episode(self, physics): - """Sets the state of the environment at the start of each episode.""" - self.initialize_robots(physics) - # randomize peg and socket position - peg_pose, socket_pose = sample_insertion_pose() - - def id2index(j_id): - return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky - - peg_start_id = physics.model.name2id("red_peg_joint", "joint") - peg_start_idx = id2index(peg_start_id) - np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) - # print(f"randomized cube position to {cube_position}") - - socket_start_id = physics.model.name2id("blue_socket_joint", "joint") - socket_start_idx = id2index(socket_start_id) - np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) - # print(f"randomized cube position to {cube_position}") - - super().initialize_episode(physics) - - @staticmethod - def get_env_state(physics): - env_state = physics.data.qpos.copy()[16:] - return env_state - - def get_reward(self, physics): - # return whether peg touches the pin - all_contact_pairs = [] - for i_contact in range(physics.data.ncon): - id_geom_1 = physics.data.contact[i_contact].geom1 - id_geom_2 = physics.data.contact[i_contact].geom2 - name_geom_1 = physics.model.id2name(id_geom_1, "geom") - name_geom_2 = physics.model.id2name(id_geom_2, "geom") - contact_pair = (name_geom_1, name_geom_2) - all_contact_pairs.append(contact_pair) - - touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs - touch_left_gripper = ( - ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs - ) - - peg_touch_table = ("red_peg", "table") in all_contact_pairs - socket_touch_table = ( - ("socket-1", "table") in all_contact_pairs - or ("socket-2", "table") in all_contact_pairs - or ("socket-3", "table") in all_contact_pairs - or ("socket-4", "table") in all_contact_pairs - ) - peg_touch_socket = ( - ("red_peg", "socket-1") in all_contact_pairs - or ("red_peg", "socket-2") in all_contact_pairs - or ("red_peg", "socket-3") in all_contact_pairs - or ("red_peg", "socket-4") in all_contact_pairs - ) - pin_touched = ("red_peg", "pin") in all_contact_pairs - - reward = 0 - if touch_left_gripper and touch_right_gripper: # touch both - reward = 1 - if ( - touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) - ): # grasp both - reward = 2 - if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching - reward = 3 - if pin_touched: # successful insertion - reward = 4 - return reward - - class AlohaEnv(EnvBase): def __init__( self, @@ -320,6 +49,7 @@ class AlohaEnv(EnvBase): num_prev_action=0, ): super().__init__(device=device, batch_size=[]) + self.task = task self.frame_skip = frame_skip self.from_pixels = from_pixels self.pixels_only = pixels_only @@ -338,27 +68,7 @@ class AlohaEnv(EnvBase): if not from_pixels: raise NotImplementedError() - # time limit is controlled by StepCounter 
in factory - time_limit = float("inf") - - if "sim_transfer_cube" in task: - xml_path = ASSETS_DIR / "bimanual_viperx_ee_transfer_cube.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = TransferCubeEETask(random=False) - env = control.Environment( - physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False - ) - elif "sim_insertion" in task: - xml_path = ASSETS_DIR / "bimanual_viperx_ee_insertion.xml" - physics = mujoco.Physics.from_xml_path(str(xml_path)) - task = InsertionEETask(random=False) - env = control.Environment( - physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False - ) - else: - raise NotImplementedError - - self._env = env + self._env = self._make_env_task(task) self._make_spec() self._current_seed = self.set_seed(seed) @@ -375,6 +85,36 @@ class AlohaEnv(EnvBase): image = self._env.physics.render(height=height, width=width, camera_id="top") return image + def _make_env_task(self, task_name): + # time limit is controlled by StepCounter in env factory + time_limit = float("inf") + + if "sim_transfer_cube" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeTask(random=False) + elif "sim_insertion" in task_name: + xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionTask(random=False) + elif "sim_end_effector_transfer_cube" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = TransferCubeEndEffectorTask(random=False) + elif "sim_end_effector_insertion" in task_name: + raise NotImplementedError() + xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml" + physics = mujoco.Physics.from_xml_path(str(xml_path)) + task = InsertionEndEffectorTask(random=False) + else: + raise NotImplementedError(task_name) + + env = control.Environment( + physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False + ) + return env + def _format_raw_obs(self, raw_obs): if self.from_pixels: image = torch.from_numpy(raw_obs["images"]["top"].copy()) @@ -396,6 +136,13 @@ class AlohaEnv(EnvBase): # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
self._current_seed += 1 self.set_seed(self._current_seed) + + # TODO(rcadene): do not use global variable for this + if "sim_transfer_cube" in self.task: + BOX_POSE[0] = sample_box_pose() # used in sim reset + elif "sim_insertion" in self.task: + BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset + raw_obs = self._env.reset() # TODO(rcadene): add assert # assert self._current_seed == self._env._seed diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py new file mode 100644 index 000000000..ee1d0927b --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim.py @@ -0,0 +1,219 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) + +BOX_POSE = [None] # to be changed from outside + +""" +Environment for simulated robot bi-manual manipulation, with joint position control +Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + left_arm_action = action[:6] + right_arm_action = action[7 : 7 + 6] + normalized_left_gripper_action = action[6] + normalized_right_gripper_action = action[7 + 6] + + left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action) + right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action) + + full_left_gripper_action = [left_gripper_action, -left_gripper_action] + full_right_gripper_action = [right_gripper_action, -right_gripper_action] + + env_action = np.concatenate( + [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action] + ) + super().before_step(env_action, physics) + return + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + 
right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + + return obs + + def get_reward(self, physics): + # return whether left gripper is holding the box + raise NotImplementedError + + +class TransferCubeTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7:] = BOX_POSE[0] + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionTask(BimanualViperXTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # TODO Notice: this function does not randomize the env configuration. 
Instead, set BOX_POSE from outside + # reset qpos, control and box position + with physics.reset_context(): + physics.named.data.qpos[:16] = START_ARM_POSE + np.copyto(physics.data.ctrl, START_ARM_POSE) + assert BOX_POSE[0] is not None + physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects + # print(f"{BOX_POSE=}") + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py new file mode 100644 index 000000000..d93c83306 --- /dev/null +++ b/lerobot/common/envs/aloha/tasks/sim_end_effector.py @@ -0,0 +1,263 @@ +import collections + +import numpy as np +from dm_control.suite import base + +from lerobot.common.envs.aloha.constants import ( + PUPPET_GRIPPER_POSITION_CLOSE, + START_ARM_POSE, + normalize_puppet_gripper_position, + normalize_puppet_gripper_velocity, + unnormalize_puppet_gripper_position, +) +from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose + +""" +Environment for simulated robot bi-manual manipulation, with end-effector control. 
+Action space: [left_arm_pose (7), # position and quaternion for end effector + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_pose (7), # position and quaternion for end effector + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + +Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"main": (480x640x3)} # h, w, c, dtype='uint8' +""" + + +class BimanualViperXEndEffectorTask(base.Task): + def __init__(self, random=None): + super().__init__(random=random) + + def before_step(self, action, physics): + a_len = len(action) // 2 + action_left = action[:a_len] + action_right = action[a_len:] + + # set mocap position and quat + # left + np.copyto(physics.data.mocap_pos[0], action_left[:3]) + np.copyto(physics.data.mocap_quat[0], action_left[3:7]) + # right + np.copyto(physics.data.mocap_pos[1], action_right[:3]) + np.copyto(physics.data.mocap_quat[1], action_right[3:7]) + + # set gripper + g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7]) + g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7]) + np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl])) + + def initialize_robots(self, physics): + # reset joint position + physics.named.data.qpos[:16] = START_ARM_POSE + + # reset mocap to align with end effector + # to obtain these numbers: + # (1) make an ee_sim env and reset to the same start_pose + # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link'] + # get env._physics.named.data.xquat['vx300s_left/gripper_link'] + # repeat the same for right side + np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084]) + np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0]) + # right + np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084])) + np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0]) + + # reset gripper control + close_gripper_control = np.array( + [ + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + PUPPET_GRIPPER_POSITION_CLOSE, + -PUPPET_GRIPPER_POSITION_CLOSE, + ] + ) + np.copyto(physics.data.ctrl, close_gripper_control) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + super().initialize_episode(physics) + + @staticmethod + def get_qpos(physics): + qpos_raw = physics.data.qpos.copy() + left_qpos_raw = qpos_raw[:8] + right_qpos_raw = qpos_raw[8:16] + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])] + right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])] + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + @staticmethod + def get_qvel(physics): + qvel_raw = physics.data.qvel.copy() + left_qvel_raw = qvel_raw[:8] + right_qvel_raw = qvel_raw[8:16] + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + 
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])] + right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])] + return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + @staticmethod + def get_env_state(physics): + raise NotImplementedError + + def get_observation(self, physics): + # note: it is important to do .copy() + obs = collections.OrderedDict() + obs["qpos"] = self.get_qpos(physics) + obs["qvel"] = self.get_qvel(physics) + obs["env_state"] = self.get_env_state(physics) + obs["images"] = {} + obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top") + obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle") + obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close") + # used in scripted policy to obtain starting pose + obs["mocap_pose_left"] = np.concatenate( + [physics.data.mocap_pos[0], physics.data.mocap_quat[0]] + ).copy() + obs["mocap_pose_right"] = np.concatenate( + [physics.data.mocap_pos[1], physics.data.mocap_quat[1]] + ).copy() + + # used when replaying joint trajectory + obs["gripper_ctrl"] = physics.data.ctrl.copy() + return obs + + def get_reward(self, physics): + raise NotImplementedError + + +class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize box position + cube_pose = sample_box_pose() + box_start_idx = physics.model.name2id("red_box_joint", "joint") + np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether left gripper is holding the box + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_table = ("red_box", "table") in all_contact_pairs + + reward = 0 + if touch_right_gripper: + reward = 1 + if touch_right_gripper and not touch_table: # lifted + reward = 2 + if touch_left_gripper: # attempted transfer + reward = 3 + if touch_left_gripper and not touch_table: # successful transfer + reward = 4 + return reward + + +class InsertionEndEffectorTask(BimanualViperXEndEffectorTask): + def __init__(self, random=None): + super().__init__(random=random) + self.max_reward = 4 + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + self.initialize_robots(physics) + # randomize peg and socket position + peg_pose, socket_pose = sample_insertion_pose() + + def id2index(j_id): + return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky + + peg_start_id = physics.model.name2id("red_peg_joint", 
"joint") + peg_start_idx = id2index(peg_start_id) + np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose) + # print(f"randomized cube position to {cube_position}") + + socket_start_id = physics.model.name2id("blue_socket_joint", "joint") + socket_start_idx = id2index(socket_start_id) + np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose) + # print(f"randomized cube position to {cube_position}") + + super().initialize_episode(physics) + + @staticmethod + def get_env_state(physics): + env_state = physics.data.qpos.copy()[16:] + return env_state + + def get_reward(self, physics): + # return whether peg touches the pin + all_contact_pairs = [] + for i_contact in range(physics.data.ncon): + id_geom_1 = physics.data.contact[i_contact].geom1 + id_geom_2 = physics.data.contact[i_contact].geom2 + name_geom_1 = physics.model.id2name(id_geom_1, "geom") + name_geom_2 = physics.model.id2name(id_geom_2, "geom") + contact_pair = (name_geom_1, name_geom_2) + all_contact_pairs.append(contact_pair) + + touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs + touch_left_gripper = ( + ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs + ) + + peg_touch_table = ("red_peg", "table") in all_contact_pairs + socket_touch_table = ( + ("socket-1", "table") in all_contact_pairs + or ("socket-2", "table") in all_contact_pairs + or ("socket-3", "table") in all_contact_pairs + or ("socket-4", "table") in all_contact_pairs + ) + peg_touch_socket = ( + ("red_peg", "socket-1") in all_contact_pairs + or ("red_peg", "socket-2") in all_contact_pairs + or ("red_peg", "socket-3") in all_contact_pairs + or ("red_peg", "socket-4") in all_contact_pairs + ) + pin_touched = ("red_peg", "pin") in all_contact_pairs + + reward = 0 + if touch_left_gripper and touch_right_gripper: # touch both + reward = 1 + if ( + touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table) + ): # grasp both + reward = 2 + if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching + reward = 3 + if pin_touched: # successful insertion + reward = 4 + return reward diff --git a/lerobot/common/policies/act/detr_vae.py b/lerobot/common/policies/act/detr_vae.py index 272eb846f..2c9704308 100644 --- a/lerobot/common/policies/act/detr_vae.py +++ b/lerobot/common/policies/act/detr_vae.py @@ -27,7 +27,7 @@ def get_sinusoid_encoding_table(n_position, d_hid): class DETRVAE(nn.Module): """This is the DETR module that performs object detection""" - def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): + def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names): """Initializes the model. Parameters: backbones: torch module of the backbone to be used. 
See backbone.py @@ -43,17 +43,18 @@ class DETRVAE(nn.Module): self.transformer = transformer self.encoder = encoder hidden_dim = transformer.d_model - self.action_head = nn.Linear(hidden_dim, state_dim) + self.action_head = nn.Linear(hidden_dim, action_dim) self.is_pad_head = nn.Linear(hidden_dim, 1) self.query_embed = nn.Embedding(num_queries, hidden_dim) if backbones is not None: self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) self.backbones = nn.ModuleList(backbones) - self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) else: # input_dim = 14 + 7 # robot_state + env_state - self.input_proj_robot_state = nn.Linear(14, hidden_dim) - self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + # TODO(rcadene): understand what is env_state, and why it needs to be 7 + self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim) self.pos = torch.nn.Embedding(2, hidden_dim) self.backbones = None @@ -180,8 +181,6 @@ def build_encoder(args): def build(args): - state_dim = 14 # TODO hardcode - # From state # backbone = None # from state for now, no need for conv nets # From image @@ -197,7 +196,8 @@ def build(args): backbones, transformer, encoder, - state_dim=state_dim, + state_dim=args.state_dim, + action_dim=args.action_dim, num_queries=args.num_queries, camera_names=args.camera_names, ) diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 77d3d4a19..13f4c199e 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -25,29 +25,6 @@ def build_act_model_and_optimizer(cfg): return model, optimizer -# def build_CNNMLP_model_and_optimizer(cfg): -# parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) -# args = parser.parse_args() - -# for k, v in cfg.items(): -# setattr(args, k, v) - -# model = build_CNNMLP_model(args) -# model.cuda() - -# param_dicts = [ -# {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, -# { -# "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], -# "lr": args.lr_backbone, -# }, -# ] -# optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, -# weight_decay=args.weight_decay) - -# return model, optimizer - - def kl_divergence(mu, logvar): batch_size = mu.size(0) assert batch_size != 0 @@ -65,9 +42,10 @@ def kl_divergence(mu, logvar): class ActionChunkingTransformerPolicy(nn.Module): - def __init__(self, cfg, device): + def __init__(self, cfg, device, n_action_steps=1): super().__init__() self.cfg = cfg + self.n_action_steps = n_action_steps self.device = device self.model, self.optimizer = build_act_model_and_optimizer(cfg) self.kl_weight = self.cfg.kl_weight @@ -179,11 +157,34 @@ class ActionChunkingTransformerPolicy(nn.Module): observation["image"] = observation["image"].unsqueeze(0) observation["state"] = observation["state"].unsqueeze(0) + # TODO(rcadene): remove hack + # add 1 camera dimension + observation["image"] = observation["image"].unsqueeze(1) + obs_dict = { "image": observation["image"], "agent_pos": observation["state"], } action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"]) + + if self.cfg.temporal_agg: + # TODO(rcadene): implement temporal aggregation + raise NotImplementedError() + # all_time_actions[[t], t:t+num_queries] = action + # actions_for_curr_step = 
all_time_actions[:, t] + # actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + # actions_for_curr_step = actions_for_curr_step[actions_populated] + # k = 0.01 + # exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + # exp_weights = exp_weights / exp_weights.sum() + # exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1) + # raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + + # remove bsize=1 + action = action.squeeze(0) + + # take first predicted action or n first actions + action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps] return action def _forward(self, qpos, image, actions=None, is_pad=None): @@ -209,46 +210,3 @@ class ActionChunkingTransformerPolicy(nn.Module): else: action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior return action - - -# class CNNMLPPolicy(nn.Module): -# def __init__(self, cfg): -# super().__init__() -# model, optimizer = build_CNNMLP_model_and_optimizer(cfg) -# self.model = model # decoder -# self.optimizer = optimizer - -# def __call__(self, qpos, image, actions=None, is_pad=None): -# env_state = None # TODO -# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], -# std=[0.229, 0.224, 0.225]) -# image = normalize(image) -# if actions is not None: # training time -# actions = actions[:, 0] -# a_hat = self.model(qpos, image, env_state, actions) -# mse = F.mse_loss(actions, a_hat) -# loss_dict = dict() -# loss_dict['mse'] = mse -# loss_dict['loss'] = loss_dict['mse'] -# return loss_dict -# else: # inference time -# a_hat = self.model(qpos, image, env_state) # no action, sample from prior -# return a_hat - -# def configure_optimizers(self): -# return self.optimizer - -# def kl_divergence(mu, logvar): -# batch_size = mu.size(0) -# assert batch_size != 0 -# if mu.data.ndimension() == 4: -# mu = mu.view(mu.size(0), mu.size(1)) -# if logvar.data.ndimension() == 4: -# logvar = logvar.view(logvar.size(0), logvar.size(1)) - -# klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) -# total_kld = klds.sum(1).mean(0, True) -# dimension_wise_kld = klds.mean(0) -# mean_kld = klds.mean(1).mean(0, True) - -# return total_kld, dimension_wise_kld, mean_kld diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 5ccd1fc46..98e8aa2f1 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -20,7 +20,9 @@ def make_policy(cfg): elif cfg.policy.name == "act": from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy - policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device) + policy = ActionChunkingTransformerPolicy( + cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps + ) else: raise ValueError(cfg.policy.name) diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index c0edbbe71..ceb8e87fe 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -21,5 +21,5 @@ env: fps: ${fps} policy: - state_dim: 2 - action_dim: 2 + state_dim: 14 + action_dim: 14 diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index 98bd92c58..358ed83cf 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -1,7 +1,5 @@ # @package _global_ -state_dim: 14 - offline_steps: 1344000 online_steps: 0 @@ -12,7 +10,9 @@ log_freq: 250 horizon: 100 n_obs_steps: 1 -n_action_steps: 1 +n_latency_steps: 0 +# when temporal_agg=False, n_action_steps=horizon 
+n_action_steps: ${horizon} policy: name: act @@ -48,3 +48,8 @@ policy: utd: 1 n_obs_steps: ${n_obs_steps} + + temporal_agg: false + + state_dim: ??? + action_dim: ???
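
Note (illustrative, not part of the patch): a minimal sketch of how the refactored joint-position task and the module-level BOX_POSE hook introduced above are expected to fit together, mirroring AlohaEnv._make_env_task and _reset. It only uses names imported or defined in this diff (ASSETS_DIR, DT, BOX_POSE, TransferCubeTask, sample_box_pose); the zero action is an arbitrary placeholder.

    # sketch of constructing and resetting the joint-position transfer-cube task
    import numpy as np
    from dm_control import mujoco
    from dm_control.rl import control

    from lerobot.common.envs.aloha.constants import ASSETS_DIR, DT
    from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, TransferCubeTask
    from lerobot.common.envs.aloha.utils import sample_box_pose

    xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
    physics = mujoco.Physics.from_xml_path(str(xml_path))
    task = TransferCubeTask(random=False)
    # time limit is handled elsewhere (StepCounter in the env factory), hence inf here
    env = control.Environment(
        physics, task, time_limit=float("inf"), control_timestep=DT,
        n_sub_steps=None, flat_observation=False,
    )

    # TransferCubeTask.initialize_episode asserts BOX_POSE[0] is set,
    # so the pose must be written before every reset (see AlohaEnv._reset)
    BOX_POSE[0] = sample_box_pose()
    ts = env.reset()

    # 14-dim action: 6 joints + 1 normalized gripper per arm
    action = np.zeros(14, dtype=np.float32)
    ts = env.step(action)
    print(ts.reward, ts.observation["qpos"].shape)

As the TODO(rcadene) in _reset notes, the module-level BOX_POSE list is a temporary escape hatch: the task itself no longer randomizes object poses, and the environment wrapper is responsible for sampling and injecting them before each episode.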