# Source file: mindbot/scripts/environments/teleoperation/teleop_xr_agent.py
#!/usr/bin/env python3
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""
Script to run teleoperation with Isaac Lab manipulation environments using PICO XR Controllers.
This script uses XRoboToolkit to fetch XR controller poses and maps them to differential IK actions.
"""
import argparse
import logging
import sys
import os
from collections.abc import Callable
# Ensure xr_utils (next to this script) is importable when running directly
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from isaaclab.app import AppLauncher
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# CLI parsing and simulator launch (must run before heavy Isaac imports)
# ---------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Teleoperation for Isaac Lab environments with PICO XR Controller."
)
parser.add_argument(
    "--num_envs", type=int, default=1, help="Number of environments to simulate."
)
parser.add_argument(
    "--task",
    type=str,
    default="Isaac-MindRobot-LeftArm-IK-Rel-v0",
    help="Name of the task.",
)
parser.add_argument(
    "--sensitivity", type=float, default=5.0, help="Sensitivity factor for pos/rot."
)
parser.add_argument(
    "--arm",
    type=str,
    default="left",
    choices=["left", "right"],
    help="Which arm/controller to use.",
)
# Let AppLauncher register its own CLI flags, then parse everything at once.
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
app_launcher_args = vars(args_cli)
# Disable some rendering settings to speed up
app_launcher_args["xr"] = False
# Launch the Omniverse app; all Isaac-dependent imports below require this.
app_launcher = AppLauncher(app_launcher_args)
simulation_app = app_launcher.app
"""Rest everything follows."""
import gymnasium as gym
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from isaaclab.envs import ManagerBasedRLEnvCfg
import isaaclab_tasks # noqa: F401
import mindbot.tasks # noqa: F401
from isaaclab_tasks.utils import parse_env_cfg
from xr_utils import XrClient, transform_xr_pose, quat_diff_as_rotvec_xyzw, is_valid_quaternion
from xr_utils.geometry import R_HEADSET_TO_WORLD
# =====================================================================
# Teleoperation Interface for XR
# =====================================================================
class XrTeleopController:
    """Teleop controller for PICO XR headset.

    Maps XR controller motion to Isaac Lab arm actions. The grip button acts
    as a clutch: motion deltas accumulate only while it is held. In relative
    mode :meth:`advance` returns a 7D delta action; in absolute mode it
    returns an 8D absolute pose action.
    """

    def __init__(self, arm="left", pos_sensitivity=1.0, rot_sensitivity=0.3, is_absolute=False):
        """Connect to the XR runtime and set up per-arm input names.

        Args:
            arm: Which controller to track, ``"left"`` or ``"right"``.
            pos_sensitivity: Scale applied to translational deltas.
            rot_sensitivity: Scale applied to rotational deltas.
            is_absolute: If True, emit absolute pose targets instead of deltas.
        """
        self.xr_client = XrClient()
        self.pos_sensitivity = pos_sensitivity
        self.rot_sensitivity = rot_sensitivity
        self.arm = arm
        self.is_absolute = is_absolute
        self.controller_name = "left_controller" if arm == "left" else "right_controller"
        self.grip_name = "left_grip" if arm == "left" else "right_grip"
        self.trigger_name = "left_trigger" if arm == "left" else "right_trigger"
        # Coordinate transform matrix (XR headset frame -> Isaac world frame)
        self.R_headset_world = R_HEADSET_TO_WORLD
        # Raw XR tracking-space pose from the previous frame (clutch anchor)
        self.prev_xr_pos = None
        self.prev_xr_quat = None
        # Absolute target state; quaternion stored as wxyz. Used only in
        # absolute mode, seeded from the current EEF pose on each grab.
        self.target_eef_pos = None
        self.target_eef_quat = None
        self.grip_active = False
        self.frame_count = 0
        # Callbacks (like reset, etc) keyed by name, triggered from XR buttons
        self.callbacks = {}

    def add_callback(self, name: str, func: Callable):
        """Register a named callback (e.g. ``"RESET"``) fired from XR buttons."""
        self.callbacks[name] = func

    def reset(self):
        """Drop clutch state so the next grip press re-anchors the reference pose."""
        self.prev_xr_pos = None
        self.prev_xr_quat = None
        self.grip_active = False
        self.frame_count = 0

    def close(self):
        """Shut down the underlying XR client connection."""
        self.xr_client.close()

    def get_zero_action(self, trigger, current_eef_pos=None, current_eef_quat=None):
        """Build a hold-in-place action with only the gripper driven by ``trigger``.

        Args:
            trigger: Trigger value; > 0.5 closes the gripper.
            current_eef_pos: Optional current EEF position (3,), used as the
                hold target in absolute mode.
            current_eef_quat: Optional current EEF orientation (wxyz), used as
                the hold target in absolute mode.

        Returns:
            Absolute mode: 8D ``[x, y, z, qw, qx, qy, qz, gripper]`` tensor.
            Relative mode: 7D all-zero delta with only the gripper set.
        """
        if self.is_absolute:
            action = torch.zeros(8)
            # Stay at the current valid pose, or fallback to the cached target
            if current_eef_pos is not None and current_eef_quat is not None:
                action[:3] = torch.tensor(current_eef_pos.copy())
                action[3:7] = torch.tensor(current_eef_quat.copy())
            elif self.target_eef_pos is not None and self.target_eef_quat is not None:
                action[:3] = torch.tensor(self.target_eef_pos.copy())
                action[3:7] = torch.tensor(self.target_eef_quat.copy())
            else:
                action[3] = 1.0  # default w=1 for quat (identity orientation)
            action[7] = 1.0 if trigger > 0.5 else -1.0
        else:
            action = torch.zeros(7)
            action[6] = 1.0 if trigger > 0.5 else -1.0
        return action

    def advance(self, current_eef_pos=None, current_eef_quat=None) -> torch.Tensor:
        """
        Reads the XR controller.
        Relative bounds return 7D action tensor: [dx, dy, dz, drx, dry, drz, gripper]
        Absolute bounds return 8D action tensor: [x, y, z, qw, qx, qy, qz, gripper]
        Note: current_eef_quat expects wxyz.
        """
        # XR buttons check (e.g. B or Y for reset). Best-effort: transient
        # query failures are ignored, but only Exception is caught so that
        # KeyboardInterrupt/SystemExit still propagate (was a bare except).
        try:
            if self.xr_client.get_button("B") or self.xr_client.get_button("Y"):
                if "RESET" in self.callbacks:
                    self.callbacks["RESET"]()
        except Exception:
            pass
        try:
            raw_pose = self.xr_client.get_pose(self.controller_name)
            grip = self.xr_client.get_key_value(self.grip_name)
            trigger = self.xr_client.get_key_value(self.trigger_name)
        except Exception:
            # Tracking dropped: hold the current pose with an open gripper.
            return self.get_zero_action(0.0, current_eef_pos, current_eef_quat)
        # Skip transformation if quaternion is invalid
        if not is_valid_quaternion(raw_pose[3:]):
            return self.get_zero_action(trigger, current_eef_pos, current_eef_quat)
        # Grip button acts as a clutch: on release, drop the anchor pose.
        if grip < 0.5:
            self.prev_xr_pos = None
            self.prev_xr_quat = None
            self.grip_active = False
            # Ensure target tracks the current pose while we are not grabbing
            if self.is_absolute and current_eef_pos is not None and current_eef_quat is not None:
                self.target_eef_pos = current_eef_pos.copy()
                self.target_eef_quat = current_eef_quat.copy()  # wxyz
            return self.get_zero_action(trigger, current_eef_pos, current_eef_quat)
        if not self.grip_active:
            # First frame of a new grab: anchor the XR reference pose.
            self.prev_xr_pos = raw_pose[:3].copy()
            self.prev_xr_quat = raw_pose[3:].copy()
            if self.is_absolute and current_eef_pos is not None and current_eef_quat is not None:
                self.target_eef_pos = current_eef_pos.copy()
                self.target_eef_quat = current_eef_quat.copy()  # wxyz
            self.grip_active = True
            return self.get_zero_action(trigger, current_eef_pos, current_eef_quat)
        xr_delta_pos = raw_pose[:3] - self.prev_xr_pos
        pos_norm = np.linalg.norm(xr_delta_pos)
        # PICO -> Isaac World mapping:
        #   XR +X (Right) -> World -Y
        #   XR +Y (Up)    -> World +Z
        #   XR +Z (Back)  -> World -X
        raw_to_world_delta_pos = np.array([-xr_delta_pos[2], -xr_delta_pos[0], xr_delta_pos[1]])
        delta_pos = raw_to_world_delta_pos * self.pos_sensitivity
        # Rotation delta as a rotation vector (axis * angle) in the XR frame
        xr_delta_rot = quat_diff_as_rotvec_xyzw(self.prev_xr_quat, raw_pose[3:])
        rot_norm = np.linalg.norm(xr_delta_rot)
        raw_to_world_delta_rot = np.array([-xr_delta_rot[2], -xr_delta_rot[0], xr_delta_rot[1]])
        delta_rot = raw_to_world_delta_rot * self.rot_sensitivity
        # Gripper: binary open/close from the trigger
        gripper_action = 1.0 if trigger > 0.5 else -1.0
        if self.is_absolute:
            # Apply delta integrally to target pose
            if self.target_eef_pos is None:
                # Failsafe if we don't have eef_pos injected yet
                self.target_eef_pos = np.zeros(3)
                self.target_eef_quat = np.array([1.0, 0.0, 0.0, 0.0])
            self.target_eef_pos += delta_pos
            # Multiply rotation. Target quat is wxyz. Delta rot is rotvec in world frame.
            # scipy Rotation uses xyzw! So convert wxyz -> xyzw, multiply, then -> wxyz
            target_r = R.from_quat([self.target_eef_quat[1], self.target_eef_quat[2], self.target_eef_quat[3], self.target_eef_quat[0]])
            delta_r = R.from_rotvec(delta_rot)
            new_r = delta_r * target_r  # Apply delta rotation in global space
            new_rot_xyzw = new_r.as_quat()
            self.target_eef_quat = np.array([new_rot_xyzw[3], new_rot_xyzw[0], new_rot_xyzw[1], new_rot_xyzw[2]])
            action = torch.tensor([
                self.target_eef_pos[0], self.target_eef_pos[1], self.target_eef_pos[2],
                self.target_eef_quat[0], self.target_eef_quat[1], self.target_eef_quat[2], self.target_eef_quat[3],
                gripper_action], dtype=torch.float32)
        else:
            # Relative clamping limits (used by relative mode to avoid divergence)
            max_pos_delta = 0.04
            if pos_norm > max_pos_delta:
                delta_pos = (raw_to_world_delta_pos * max_pos_delta / pos_norm) * self.pos_sensitivity
            max_rot_delta = 0.02
            if rot_norm > max_rot_delta:
                delta_rot = (raw_to_world_delta_rot * max_rot_delta / rot_norm) * self.rot_sensitivity
            action = torch.tensor([
                delta_pos[0], delta_pos[1], delta_pos[2],
                delta_rot[0], delta_rot[1], delta_rot[2],
                gripper_action], dtype=torch.float32)
        # Advance the clutch anchor to the current raw pose (was duplicated
        # in both branches; hoisted here once).
        self.prev_xr_pos = raw_pose[:3].copy()
        self.prev_xr_quat = raw_pose[3:].copy()
        self.frame_count += 1
        if self.frame_count % 30 == 0:
            np.set_printoptions(precision=4, suppress=True, floatmode='fixed')
            print("\n====================== [VR TELEOP DEBUG] ======================")
            print(f"| Task Mode: {'ABSOLUTE' if self.is_absolute else 'RELATIVE'}")
            print(f"| Raw VR Pos (OpenXR): {np.array(raw_pose[:3])}")
            print(f"| Raw VR Quat (xyzw): {np.array(raw_pose[3:])}")
            print(f"| XR Delta Pos (raw): {xr_delta_pos} (norm={pos_norm:.4f})")
            print(f"| XR Delta Rot (raw): {xr_delta_rot} (norm={rot_norm:.4f})")
            if self.is_absolute:
                print(f"| Targ Pos (dx,dy,dz): {action[:3].numpy()}")
                print(f"| Targ Quat (wxyz): {action[3:7].numpy()}")
            else:
                print(f"| Sent Action Pos (dx,dy,dz): {action[:3].numpy()}")
                print(f"| Sent Action Rot (rx,ry,rz): {action[3:6].numpy()}")
            print(f"| Gripper & Trigger: Grip={grip:.2f}, Trig={trigger:.2f}")
            print("===============================================================")
        return action
# =====================================================================
# Main Execution Loop
# =====================================================================
def _arm_joint_indices(env):
    """Return cached indices of the first 7 joints that are not fingers/hands.

    Memoized on the env object so the joint-name scan runs only once.
    """
    if not hasattr(env, '_arm_idx_cache'):
        robot = env.unwrapped.scene["robot"]
        jnames = robot.joint_names
        env._arm_idx_cache = [i for i, n in enumerate(jnames) if "finger" not in n and "hand" not in n][:7]
    return env._arm_idx_cache


def _print_robot_state(env, obs, sim_frame):
    """Print joint/EEF debug state; on the first call also dump all joint names."""
    np.set_printoptions(precision=4, suppress=True, floatmode='fixed')
    policy_obs = obs["policy"]
    joint_pos = policy_obs["joint_pos"][0].cpu().numpy()
    eef_pos = policy_obs["eef_pos"][0].cpu().numpy()
    eef_quat = policy_obs["eef_quat"][0].cpu().numpy()
    last_act = policy_obs["actions"][0].cpu().numpy()
    # On first print, dump ALL joint names + positions to identify indices
    if sim_frame == 30:
        robot = env.unwrapped.scene["robot"]
        jnames = robot.joint_names
        print(f"\n{'='*70}")
        print(f" ALL {len(jnames)} JOINT NAMES AND POSITIONS (relative)")
        print(f"{'='*70}")
        for i, name in enumerate(jnames):
            print(f" [{i:2d}] {name:30s} = {joint_pos[i]:+.4f}")
        print(f"{'='*70}")
        # Find arm joint indices dynamically by looking at the first 6-7 joints that aren't fingers or hands
        arm_idx = [i for i, n in enumerate(jnames) if "finger" not in n and "hand" not in n][:7]
        print(f" Deduced arm indices: {arm_idx}")
        print(f"{'='*70}\n")
    arm_joints = joint_pos[_arm_joint_indices(env)]
    print(f"\n---------------- [ROBOT STATE frame={sim_frame}] ----------------")
    print(f"| Left Arm Joints (rad): {arm_joints}")
    print(f"| EEF Pos (world): {eef_pos}")
    print(f"| EEF Quat (world, wxyz): {eef_quat}")
    print(f"| Last Action Sent: {last_act}")
    print(f"----------------------------------------------------------------")


def main() -> None:
    """Run teleoperation with PICO XR Controller against Isaac Lab environment."""
    # 1. Configuration parsing
    env_cfg = parse_env_cfg(args_cli.task, num_envs=args_cli.num_envs)
    env_cfg.env_name = args_cli.task
    if not isinstance(env_cfg, ManagerBasedRLEnvCfg):
        raise ValueError(f"Teleoperation requires ManagerBasedRLEnvCfg. Got: {type(env_cfg)}")
    # Teleop sessions should never terminate on episode timeout.
    env_cfg.terminations.time_out = None
    # 2. Environment creation
    try:
        env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
    except Exception as e:
        logger.error(f"Failed to create environment '{args_cli.task}': {e}")
        simulation_app.close()
        return
    # 3. Teleoperation Interface Initialization
    is_abs_mode = "Abs" in args_cli.task
    print(f"\n[INFO] Connecting to PICO XR Headset using {args_cli.arm} controller...")
    print(f"[INFO] Using IK Mode: {'ABSOLUTE' if is_abs_mode else 'RELATIVE'}")
    teleop_interface = XrTeleopController(
        arm=args_cli.arm,
        pos_sensitivity=args_cli.sensitivity,
        rot_sensitivity=args_cli.sensitivity * (1.0 if is_abs_mode else 0.3),  # Absolute rotation handles 1:1 better than relative
        is_absolute=is_abs_mode,
    )
    should_reset = False

    def request_reset():
        # Called from the XR button callback; actual reset happens in the loop.
        nonlocal should_reset
        should_reset = True
        print("[INFO] Reset requested via XR button.")

    teleop_interface.add_callback("RESET", request_reset)
    # Single reset seeds the first observation (original code reset twice).
    obs, _ = env.reset()
    teleop_interface.reset()
    print("\n" + "=" * 50)
    print(" 🚀 Teleoperation Started!")
    print(" 🎮 Use the TRIGGER to open/close gripper.")
    print(" ✊ Hold GRIP button and move the controller to move the arm.")
    print(" 🕹️ Press B or Y to reset the environment.")
    print("=" * 50 + "\n")
    device = env.unwrapped.device
    sim_frame = 0
    while simulation_app.is_running():
        try:
            with torch.inference_mode():
                # Extract tracking pose to seed absolute IK
                policy_obs = obs["policy"]
                eef_pos = policy_obs["eef_pos"][0].cpu().numpy()
                eef_quat = policy_obs["eef_quat"][0].cpu().numpy()
                # Get action from XR Controller and broadcast to every env
                action = teleop_interface.advance(current_eef_pos=eef_pos, current_eef_quat=eef_quat)
                actions = action.unsqueeze(0).repeat(env.num_envs, 1).to(device)
                # Step environment
                obs, _, _, _, _ = env.step(actions)
                # Print robot state every 30 frames
                sim_frame += 1
                if sim_frame % 30 == 0:
                    _print_robot_state(env, obs, sim_frame)
                if should_reset:
                    # Capture the post-reset observation so the next advance()
                    # seeds from the fresh EEF pose (was discarded before).
                    obs, _ = env.reset()
                    teleop_interface.reset()
                    should_reset = False
        except Exception as e:
            logger.error(f"Error during simulation step: {e}")
            break
    teleop_interface.close()
    env.close()
if __name__ == "__main__":
    # Run the teleop session, then make sure the simulator shuts down.
    main()
    simulation_app.close()