from copy import deepcopy

import numpy as np
from core.skills.base_skill import BaseSkill, register_skill
from omegaconf import DictConfig
from omni.isaac.core.controllers import BaseController
from omni.isaac.core.robots.robot import Robot
from omni.isaac.core.tasks import BaseTask
from omni.isaac.core.utils.prims import get_prim_at_path
from omni.isaac.core.utils.transformations import get_relative_transform
from scipy.spatial.transform import Rotation as R
from solver.planner import KPAMPlanner


# pylint: disable=unused-argument
@register_skill
class Close(BaseSkill):
    """Skill that moves the end effector through KPAM keyposes to close an articulated object.

    Motion is emitted as a queue of commands in ``self.manip_list``. Each command is a tuple
    ``(position, quaternion_wxyz, command_name, options)`` that is consumed elsewhere
    (presumably by the controller). Success is judged from the articulated object's joint
    position (see ``is_success``), gated by contact-based and velocity-based validity flags.
    """

    def __init__(self, robot: Robot, controller: BaseController, task: BaseTask, cfg: DictConfig, *args, **kwargs):
        """Bind the skill to a robot/controller/task and read its configuration.

        Args:
            robot: Robot executing the skill.
            controller: Controller used for planning; the arm is "left" if the substring
                "left" occurs in ``controller.robot_file``, otherwise "right".
            task: Task providing ``stage``, ``objects`` and ``artcontact_views``.
            cfg: Skill config. Reads ``name``, ``objects[0]`` (articulated object name),
                ``planner_setting`` and, optionally, ``obj_info_path``,
                ``ignore_substring``, ``collision_valid`` and ``process_valid``.
            **kwargs: Optional ``world`` (passed to the planner as ``env``) and ``draw``
                (debug drawer; a falsy value disables keypose drawing). Missing keys
                default to ``None``.
        """
        super().__init__()
        self.robot = robot
        self.controller = controller
        self.task = task
        self.stage = task.stage
        self.name = cfg["name"]
        art_obj_name = cfg["objects"][0]
        self.skill_cfg = cfg
        self.art_obj = task.objects[art_obj_name]
        self.planner_setting = cfg["planner_setting"]
        self.contact_pose_index = self.planner_setting["contact_pose_index"]
        self.success_threshold = self.planner_setting["success_threshold"]
        self.update_art_joint = self.planner_setting.get("update_art_joint", False)
        # Always define these so setup_kpam / simple_generate_manip_cmds never hit a
        # missing attribute when the caller passes no (or partial) kwargs.
        self.world = kwargs.get("world")
        self.draw = kwargs.get("draw")
        self.manip_list = []

        if self.skill_cfg.get("obj_info_path", None):
            self.art_obj.update_articulated_info(self.skill_cfg["obj_info_path"])

        lr_arm = "left" if "left" in self.controller.robot_file else "right"
        arm_views = task.artcontact_views[robot.name][lr_arm]
        self.fingers_link_contact_view = arm_views[art_obj_name + "_fingers_link"]
        self.fingers_base_contact_view = arm_views[art_obj_name + "_fingers_base"]
        self.forbid_collision_contact_view = arm_views[art_obj_name + "_forbid_collision"]
        # collision_valid is latched: once is_success clears it, it never becomes True again.
        self.collision_valid = True
        self.process_valid = True
        self.success_mode = self.planner_setting.get("success_mode", "zero")

    def setup_kpam(self):
        """Create a fresh ``KPAMPlanner`` for the articulated object and store it in ``self.planner``."""
        self.planner = KPAMPlanner(
            env=self.world,
            robot=self.robot,
            object=self.art_obj,
            cfg_path=self.planner_setting,
            controller=self.controller,
            draw_points=self.draw,
            stage=self.stage,
        )

    def _base_ignore_substring(self):
        """Return a fresh copy of the controller's plus the skill config's ignore substrings."""
        return deepcopy(self.controller.ignore_substring + self.skill_cfg.get("ignore_substring", []))

    def simple_generate_manip_cmds(self):
        """Plan keyposes with KPAM and rebuild ``self.manip_list``.

        The resulting queue is:

        1. an ``update_specific`` command at the current EE pose, using the controller's and
           config's ignore substrings;
        2. one ``close_gripper`` command per planned keyframe, with the quaternion in
           scalar-first (w, x, y, z) order;
        3. right after keyframe ``contact_pose_index - 1``, an extra ``update_specific``
           command whose ignore list also contains the name of the articulated object's
           parent prim (presumably so contact with the object is tolerated from then on).

        If the planner returns neither keyframes nor sample times, the queue is emptied and
        nothing else is stored.
        """
        if self.skill_cfg.get("obj_info_path", None):
            self.art_obj.update_articulated_info(self.skill_cfg["obj_info_path"])

        self.setup_kpam()
        traj_keyframes, sample_times = self.planner.get_keypose()
        if len(traj_keyframes) == 0 and len(sample_times) == 0:
            print("No keyframes found, return empty manip_list")
            self.manip_list = []
            return

        T_world_base = get_relative_transform(
            get_prim_at_path(self.robot.base_path), get_prim_at_path(self.task.root_prim_path)
        )
        self.traj_keyframes = traj_keyframes
        self.sample_times = sample_times
        if self.draw:
            # Keyframe translations are mapped through T_world_base before drawing, i.e. they
            # are presumably expressed in the robot base frame.
            for keypose in traj_keyframes:
                self.draw.draw_points([(T_world_base @ np.append(keypose[:3, 3], 1))[:3]], [(0, 0, 0, 1)], [7])

        # Initial "update_specific" command issued at the current EE pose.
        p_base_ee_cur, q_base_ee_cur = self.controller.get_ee_pose()
        manip_list = [
            (
                p_base_ee_cur,
                q_base_ee_cur,
                "update_specific",
                {
                    "ignore_substring": self._base_ignore_substring(),
                    "reference_prim_path": self.controller.reference_prim_path,
                },
            )
        ]

        contact_step = self.contact_pose_index - 1
        for i, keyframe in enumerate(self.traj_keyframes):
            p_base_ee_tgt = keyframe[:3, 3]
            q_base_ee_tgt = R.from_matrix(keyframe[:3, :3]).as_quat(scalar_first=True)
            manip_list.append((p_base_ee_tgt, q_base_ee_tgt, "close_gripper", {}))

            if i == contact_step:
                # Same pose as the keyframe just queued, but additionally ignore the parent
                # prim of the articulated object.
                ignore_substring = self._base_ignore_substring()
                ignore_substring.append(self.art_obj.prim_path.split("/")[-2])
                manip_list.append(
                    (
                        p_base_ee_tgt,
                        q_base_ee_tgt,
                        "update_specific",
                        {
                            "ignore_substring": ignore_substring,
                            "reference_prim_path": self.controller.reference_prim_path,
                        },
                    )
                )

        self.manip_list = manip_list

    def update(self):
        """Pin the articulated joint's position target to its current position after success.

        Only active when ``planner_setting.update_art_joint`` is truthy. Note that this calls
        ``is_success``, which prints diagnostics and may clear ``collision_valid``.
        """
        if self.update_art_joint and self.is_success():
            joint_index = self.art_obj.object_joint_index
            curr_joint_p = self.art_obj._articulation_view.get_joint_positions()[:, joint_index]
            self.art_obj._articulation_view.set_joint_position_targets(positions=curr_joint_p, joint_indices=joint_index)

    @staticmethod
    def _summarize_contact(contact_view, key, contact_threshold):
        """Summarise one contact view as ``{key_contact: forces, key_contact_indices: idx}``.

        ``forces`` is the absolute contact-force matrix, squeezed and summed over its last
        axis; ``idx`` holds the positions where that sum exceeds ``contact_threshold``.
        NOTE(review): assumes the squeezed matrix is at least 2-D so the sum leaves a 1-D
        array -- confirm against the contact view's filter count.
        """
        forces = np.sum(np.abs(contact_view.get_contact_force_matrix()).squeeze(), axis=-1)
        indices = np.where(forces > contact_threshold)[0]
        return {f"{key}_contact": forces, f"{key}_contact_indices": indices}

    def get_contact(self, contact_threshold=0.0):
        """Return contact summaries for the finger links, finger base and forbidden-collision views.

        The result maps ``"fingers_link"``, ``"fingers_base"`` and ``"forbid_collision"`` to
        dicts with ``<name>_contact`` (summed absolute forces) and ``<name>_contact_indices``
        (indices whose summed force exceeds ``contact_threshold``).
        """
        views = {
            "fingers_link": self.fingers_link_contact_view,
            "fingers_base": self.fingers_base_contact_view,
            "forbid_collision": self.forbid_collision_contact_view,
        }
        return {key: self._summarize_contact(view, key, contact_threshold) for key, view in views.items()}

    def is_feasible(self, th=5):
        """Return True while the controller has failed to plan at most ``th`` times."""
        return self.controller.num_plan_failed <= th

    def is_subtask_done(self, t_eps=1e-3, o_eps=5e-3):
        """Return whether the command at the head of ``manip_list`` counts as reached.

        True when the current EE pose is within ``t_eps`` (Euclidean position distance) and
        ``o_eps`` (rotation angle ``2 * arccos(|<q_cur, q_tgt>|)``) of the target, or when
        ``controller.num_last_cmd`` exceeds 10 (that flag is also stored in ``self.plan_flag``).
        """
        assert len(self.manip_list) != 0
        p_base_ee_cur, q_base_ee_cur = self.controller.get_ee_pose()
        p_base_ee, q_base_ee, *_ = self.manip_list[0]

        diff_trans = np.linalg.norm(p_base_ee_cur - p_base_ee)
        # abs() treats q and -q (the same rotation) alike; min() keeps arccos in its domain
        # despite floating-point round-off above 1.
        diff_ori = 2 * np.arccos(min(abs(np.dot(q_base_ee_cur, q_base_ee)), 1.0))

        pose_flag = np.logical_and(diff_trans < t_eps, diff_ori < o_eps)
        self.plan_flag = self.controller.num_last_cmd > 10
        return np.logical_or(pose_flag, self.plan_flag)

    def is_done(self):
        """Advance the command queue and return True once it is empty.

        When the head command is reached it is popped; if the skill is then successful the
        remaining queue is discarded.
        """
        if not self.manip_list:
            return True
        if self.is_subtask_done():
            self.manip_list.pop(0)
            print("POP one manip cmd")
            if self.is_success():
                self.manip_list.clear()
                print("Close Done")
        return len(self.manip_list) == 0

    def is_success(self):
        """Return whether the articulated object counts as closed.

        Side effects:
            * if ``collision_valid`` is enabled in the config (default), ``self.collision_valid``
              is latched to False as soon as any forbidden-collision or finger-base contact
              is seen;
            * if ``process_valid`` is enabled in the config (default), ``self.process_valid``
              is recomputed as "max robot joint speed < 5 and max object linear speed < 5".

        Success modes (``planner_setting.success_mode``):
            * ``"zero"``: ``|q| <= success_threshold`` for the object joint position ``q``;
            * ``"dis_to_init"``: ``|q - q_init| >= |success_threshold|``.
        Any other mode never reports success. The joint condition is combined with both
        validity flags.
        """
        contact = self.get_contact()

        if self.skill_cfg.get("collision_valid", True):
            self.collision_valid = (
                self.collision_valid
                and len(contact["forbid_collision"]["forbid_collision_contact_indices"]) == 0
                and len(contact["fingers_base"]["fingers_base_contact_indices"]) == 0
            )
        if self.skill_cfg.get("process_valid", True):
            # Velocity limit of 5 (units as reported by the simulator) rejects violent motion.
            self.process_valid = np.max(np.abs(self.robot.get_joints_state().velocities)) < 5 and (
                np.max(np.abs(self.art_obj.get_linear_velocity())) < 5
            )

        curr_joint_p = self.art_obj._articulation_view.get_joint_positions()[:, self.art_obj.object_joint_index]
        init_joint_p = self.art_obj.articulation_initial_joint_position
        print(
            "curr_joint_p: ",
            curr_joint_p,
            "init_joint_p: ",
            init_joint_p,
            "distance: ",
            np.abs(curr_joint_p - init_joint_p),
            "collision_valid :",
            self.collision_valid,
            "process_valid :",
            self.process_valid,
        )

        if self.success_mode == "zero":
            joint_ok = np.abs(curr_joint_p) <= self.success_threshold
        elif self.success_mode == "dis_to_init":
            joint_ok = np.abs(curr_joint_p - init_joint_p) >= np.abs(self.success_threshold)
        else:
            # Unrecognised success_mode: never report success (explicit instead of implicit None).
            return False
        return joint_ok and self.collision_valid and self.process_valid