This commit is contained in:
Simon Alibert
2024-03-24 19:31:47 +01:00
parent b905111895
commit 127de1258d
2 changed files with 29 additions and 23 deletions

View File

@@ -63,7 +63,8 @@ class Base(robot_env.MujocoRobotEnv):
return self._get_obs() return self._get_obs()
def _step_callback(self): def _step_callback(self):
self.sim.forward() # self.sim.forward()
self._mujoco.mj_forward(self.model, self.data)
def _limit_gripper(self, gripper_pos, pos_ctrl): def _limit_gripper(self, gripper_pos, pos_ctrl):
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15: if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
@@ -88,7 +89,12 @@ class Base(robot_env.MujocoRobotEnv):
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
) * (1 / self.n_substeps) ) * (1 / self.n_substeps)
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
mocap.apply_action(self.sim, np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl])) mocap.apply_action(
self.model,
self._model_names,
self.data,
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
)
def _viewer_setup(self): def _viewer_setup(self):
body_id = self.sim.model.body_name2id("link7") body_id = self.sim.model.body_name2id("link7")
@@ -144,8 +150,9 @@ class Base(robot_env.MujocoRobotEnv):
assert action.shape == (4,) assert action.shape == (4,)
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action)) assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
self._apply_action(action) self._apply_action(action)
for _ in range(2): # for _ in range(2):
self.sim.step() # self.sim.step()
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback() self._step_callback()
obs = self._get_obs() obs = self._get_obs()
reward = self.get_reward() reward = self.get_reward()

View File

@@ -3,17 +3,17 @@ import mujoco
import numpy as np import numpy as np
def apply_action(sim, action): def apply_action(model, model_names, data, action):
if sim.model.nmocap > 0: if model.nmocap > 0:
pos_action, gripper_action = np.split(action, (sim.model.nmocap * 7,)) pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
if sim.data.ctrl is not None: if data.ctrl is not None:
for i in range(gripper_action.shape[0]): for i in range(gripper_action.shape[0]):
sim.data.ctrl[i] = gripper_action[i] data.ctrl[i] = gripper_action[i]
pos_action = pos_action.reshape(sim.model.nmocap, 7) pos_action = pos_action.reshape(model.nmocap, 7)
pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:] pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
reset_mocap2body_xpos(sim) reset_mocap2body_xpos(model, model_names, data)
sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta data.mocap_pos[:] = data.mocap_pos + pos_delta
sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta data.mocap_quat[:] = data.mocap_quat + quat_delta
def reset(model, data): def reset(model, data):
@@ -41,28 +41,27 @@ def reset(model, data):
mujoco.mj_forward(model, data) mujoco.mj_forward(model, data)
def reset_mocap2body_xpos(sim): def reset_mocap2body_xpos(model, model_names, data):
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None: if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
return return
# For all weld constraints # For all weld constraints
for eq_type, obj1_id, obj2_id in zip( for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id, strict=False
):
# if eq_type != mujoco_py.const.EQ_WELD: # if eq_type != mujoco_py.const.EQ_WELD:
if eq_type != mujoco.mjtEq.mjEQ_WELD: if eq_type != mujoco.mjtEq.mjEQ_WELD:
continue continue
body2 = sim.model.body_id2name(obj2_id) # body2 = model.body_id2name(obj2_id)
body2 = model_names.body_id2name[obj2_id]
if body2 == "B0" or body2 == "B9" or body2 == "B1": if body2 == "B0" or body2 == "B9" or body2 == "B1":
continue continue
mocap_id = sim.model.body_mocapid[obj1_id] mocap_id = model.body_mocapid[obj1_id]
if mocap_id != -1: if mocap_id != -1:
# obj1 is the mocap, obj2 is the welded body # obj1 is the mocap, obj2 is the welded body
body_idx = obj2_id body_idx = obj2_id
else: else:
# obj2 is the mocap, obj1 is the welded body # obj2 is the mocap, obj1 is the welded body
mocap_id = sim.model.body_mocapid[obj2_id] mocap_id = model.body_mocapid[obj2_id]
body_idx = obj1_id body_idx = obj1_id
assert mocap_id != -1 assert mocap_id != -1
sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx] data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx] data.mocap_quat[mocap_id][:] = data.xquat[body_idx]