From 127de1258d2c4c2b9dc56b7fdc61106899978894 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sun, 24 Mar 2024 19:31:47 +0100 Subject: [PATCH] WIP --- .../common/envs/simxarm/simxarm/task/base.py | 15 ++++++-- .../common/envs/simxarm/simxarm/task/mocap.py | 37 +++++++++---------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/lerobot/common/envs/simxarm/simxarm/task/base.py b/lerobot/common/envs/simxarm/simxarm/task/base.py index d91f61c1..f9829c2c 100644 --- a/lerobot/common/envs/simxarm/simxarm/task/base.py +++ b/lerobot/common/envs/simxarm/simxarm/task/base.py @@ -63,7 +63,8 @@ class Base(robot_env.MujocoRobotEnv): return self._get_obs() def _step_callback(self): - self.sim.forward() + # self.sim.forward() + self._mujoco.mj_forward(self.model, self.data) def _limit_gripper(self, gripper_pos, pos_ctrl): if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15: @@ -88,7 +89,12 @@ class Base(robot_env.MujocoRobotEnv): self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl ) * (1 / self.n_substeps) gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) - mocap.apply_action(self.sim, np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl])) + mocap.apply_action( + self.model, + self._model_names, + self.data, + np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]), + ) def _viewer_setup(self): body_id = self.sim.model.body_name2id("link7") @@ -144,8 +150,9 @@ class Base(robot_env.MujocoRobotEnv): assert action.shape == (4,) assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action)) self._apply_action(action) - for _ in range(2): - self.sim.step() + # for _ in range(2): + # self.sim.step() + self._mujoco.mj_step(self.model, self.data, nstep=2) self._step_callback() obs = self._get_obs() reward = self.get_reward() diff --git a/lerobot/common/envs/simxarm/simxarm/task/mocap.py b/lerobot/common/envs/simxarm/simxarm/task/mocap.py index 45722f13..4295bf19 100644 --- a/lerobot/common/envs/simxarm/simxarm/task/mocap.py +++ b/lerobot/common/envs/simxarm/simxarm/task/mocap.py @@ -3,17 +3,17 @@ import mujoco import numpy as np -def apply_action(sim, action): - if sim.model.nmocap > 0: - pos_action, gripper_action = np.split(action, (sim.model.nmocap * 7,)) - if sim.data.ctrl is not None: +def apply_action(model, model_names, data, action): + if model.nmocap > 0: + pos_action, gripper_action = np.split(action, (model.nmocap * 7,)) + if data.ctrl is not None: for i in range(gripper_action.shape[0]): - sim.data.ctrl[i] = gripper_action[i] - pos_action = pos_action.reshape(sim.model.nmocap, 7) + data.ctrl[i] = gripper_action[i] + pos_action = pos_action.reshape(model.nmocap, 7) pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:] - reset_mocap2body_xpos(sim) - sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta - sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta + reset_mocap2body_xpos(model, model_names, data) + data.mocap_pos[:] = data.mocap_pos + pos_delta + data.mocap_quat[:] = data.mocap_quat + quat_delta def reset(model, data): @@ -41,28 +41,27 @@ def reset(model, data): mujoco.mj_forward(model, data) -def reset_mocap2body_xpos(sim): - if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None: +def reset_mocap2body_xpos(model, model_names, data): + if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None: return # For all weld constraints - for eq_type, obj1_id, obj2_id in zip( - sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id, strict=False - ): + for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False): # if eq_type != mujoco_py.const.EQ_WELD: if eq_type != mujoco.mjtEq.mjEQ_WELD: continue - body2 = sim.model.body_id2name(obj2_id) + # body2 = model.body_id2name(obj2_id) + body2 = model_names.body_id2name[obj2_id] if body2 == "B0" or body2 == "B9" or body2 == "B1": continue - mocap_id = sim.model.body_mocapid[obj1_id] + mocap_id = model.body_mocapid[obj1_id] if mocap_id != -1: # obj1 is the mocap, obj2 is the welded body body_idx = obj2_id else: # obj2 is the mocap, obj1 is the welded body - mocap_id = sim.model.body_mocapid[obj2_id] + mocap_id = model.body_mocapid[obj2_id] body_idx = obj1_id assert mocap_id != -1 - sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx] - sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx] + data.mocap_pos[mocap_id][:] = data.xpos[body_idx] + data.mocap_quat[mocap_id][:] = data.xquat[body_idx]