Files
mindbot/scripts/environments/teleoperation/teleop_xr_agent.py
2026-03-05 22:41:56 +08:00

336 lines
13 KiB
Python

#!/usr/bin/env python3
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""
Script to run teleoperation with Isaac Lab manipulation environments using PICO XR Controllers.
This script uses XRoboToolkit to fetch XR controller poses and maps them to differential IK actions.
"""
import argparse
import logging
import sys
import os
from collections.abc import Callable
# Ensure xr_utils (next to this script) is importable when running directly
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from isaaclab.app import AppLauncher
logger = logging.getLogger(__name__)
# add argparse arguments
parser = argparse.ArgumentParser(
    description="Teleoperation for Isaac Lab environments with PICO XR Controller."
)
parser.add_argument(
    "--num_envs", type=int, default=1, help="Number of environments to simulate."
)
parser.add_argument(
    "--task",
    type=str,
    default="Isaac-MindRobot-LeftArm-IK-Rel-v0",
    help="Name of the task.",
)
parser.add_argument(
    "--sensitivity",
    type=float,
    default=5.0,
    help="Sensitivity factor for position deltas (rotation uses 0.3x this value).",
)
parser.add_argument(
    "--arm",
    type=str,
    default="left",
    choices=["left", "right"],
    help="Which arm/controller to use.",
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# NOTE: vars() returns the namespace's own __dict__, so the edit below is also
# visible as args_cli.xr.
app_launcher_args = vars(args_cli)
# Keep the AppLauncher's own "xr" option off: this script reads controller poses
# through XRoboToolkit (XrClient) rather than through Isaac Lab's XR mode.
app_launcher_args["xr"] = False
# launch omniverse app; the isaaclab/gym imports further down are deliberately
# placed after this launch.
app_launcher = AppLauncher(app_launcher_args)
simulation_app = app_launcher.app
"""Rest everything follows."""
import gymnasium as gym
import numpy as np
import torch
from isaaclab.envs import ManagerBasedRLEnvCfg
import isaaclab_tasks # noqa: F401
import mindbot.tasks # noqa: F401
from isaaclab_tasks.utils import parse_env_cfg
from xr_utils import XrClient, transform_xr_pose, quat_diff_as_rotvec_xyzw, is_valid_quaternion
from xr_utils.geometry import R_HEADSET_TO_WORLD
# =====================================================================
# Teleoperation Interface for XR
# =====================================================================
class XrTeleopController:
    """Teleop controller that turns PICO XR controller motion into relative IK actions.

    The grip button acts as a clutch. While it is held, the frame-to-frame motion of
    the selected controller becomes a 7-D action
    ``[dx, dy, dz, droll, dpitch, dyaw, gripper_cmd]``. While it is released, the arm
    receives zero motion. The trigger drives ``gripper_cmd`` regardless of the clutch:
    ``1.0`` when pressed past 0.5, ``-1.0`` otherwise. Pressing B or Y fires the
    ``"RESET"`` callback once per press.
    """

    # Per-frame clamps on the raw XR deltas, applied before axis mapping and scaling.
    # Position: ~4 cm per frame (assuming XR positions are in meters).
    MAX_POS_DELTA = 0.04
    # Rotation: ~0.02 rad (~1.1 deg) per frame. Keeping this small is critical: DLS IK
    # can only solve small deltas as pure rotations; large deltas cause the IK to swing
    # the whole arm instead of rotating in place.
    MAX_ROT_DELTA = 0.02

    def __init__(self, arm="left", pos_sensitivity=1.0, rot_sensitivity=0.3):
        """Connect to the XR runtime and select which controller to read.

        Args:
            arm: ``"left"`` selects the left controller; any other value selects the
                right one.
            pos_sensitivity: Scale applied to the mapped position delta.
            rot_sensitivity: Scale applied to the mapped rotation delta.
        """
        self.xr_client = XrClient()
        self.pos_sensitivity = pos_sensitivity
        self.rot_sensitivity = rot_sensitivity
        self.arm = arm
        self.controller_name = "left_controller" if arm == "left" else "right_controller"
        self.grip_name = "left_grip" if arm == "left" else "right_grip"
        self.trigger_name = "left_trigger" if arm == "left" else "right_trigger"
        # Headset-to-world rotation; kept as a public attribute. advance() maps the XR
        # axes explicitly in _xr_to_world instead of using this matrix.
        self.R_headset_world = R_HEADSET_TO_WORLD
        # Clutch anchor: raw XR tracking-space pose from the previous active frame
        # (NOT transformed to world).
        self.prev_xr_pos = None
        self.prev_xr_quat = None
        self.grip_active = False
        self.frame_count = 0
        # Whether B/Y was down on the previous poll, so RESET fires only on the press edge.
        self._reset_button_held = False
        # Named callbacks (e.g. "RESET") registered by the caller.
        self.callbacks = {}

    def add_callback(self, name: str, func: Callable):
        """Register ``func`` under ``name``; a later registration replaces an earlier one."""
        self.callbacks[name] = func

    def reset(self):
        """Drop the clutch anchor and the debug frame counter."""
        self._release_clutch()
        self.frame_count = 0

    def close(self):
        """Close the underlying XR client."""
        self.xr_client.close()

    def _release_clutch(self) -> None:
        """Forget the anchor pose so the next engaged frame starts a fresh reference."""
        self.prev_xr_pos = None
        self.prev_xr_quat = None
        self.grip_active = False

    @staticmethod
    def _gripper_cmd(trigger) -> float:
        """Map the analog trigger value to a binary gripper command (+1 / -1)."""
        return 1.0 if trigger > 0.5 else -1.0

    def _idle_action(self, trigger) -> torch.Tensor:
        """Return an action with zero arm motion and only the gripper command set."""
        action = torch.zeros(7)
        action[6] = self._gripper_cmd(trigger)
        return action

    @staticmethod
    def _xr_to_world(vec) -> np.ndarray:
        """Map a 3-vector from the PICO XR frame to the Isaac world frame.

        Axis mapping:
            XR +X (right) -> world -Y
            XR +Y (up)    -> world +Z
            XR +Z (back)  -> world -X

        The mapping matrix has determinant +1 (a proper rotation), so the same mapping
        applies to position deltas and to axis-angle rotation vectors.
        """
        return np.array([-vec[2], -vec[0], vec[1]])

    def _poll_reset_button(self) -> None:
        """Fire the ``"RESET"`` callback once when B or Y goes from released to pressed.

        This is polled every frame. Without edge detection, a held button would request
        a reset on every frame it stays down. Button read errors are logged at debug
        level and otherwise ignored, so a failed read does not stop teleoperation.
        Exceptions raised by the callback itself propagate to the caller.
        """
        try:
            pressed = bool(self.xr_client.get_button("B") or self.xr_client.get_button("Y"))
        except Exception:
            logger.debug("Failed to read XR reset buttons", exc_info=True)
            return
        was_held = self._reset_button_held
        # Update the state before invoking the callback, so a raising callback cannot
        # cause a repeated trigger on the next frame.
        self._reset_button_held = pressed
        if pressed and not was_held:
            callback = self.callbacks.get("RESET")
            if callback is not None:
                callback()

    def advance(self) -> torch.Tensor:
        """Poll the XR controller and return the next 7-D action.

        Returns:
            A tensor ``[dx, dy, dz, droll, dpitch, dyaw, gripper_cmd]``. The motion
            components are non-zero only when the grip is held (>= 0.5) and the clutch
            was already engaged on an earlier frame. If the controller cannot be read,
            an all-zero tensor is returned.
        """
        self._poll_reset_button()

        try:
            raw_pose = self.xr_client.get_pose(self.controller_name)
            grip = self.xr_client.get_key_value(self.grip_name)
            trigger = self.xr_client.get_key_value(self.trigger_name)
        except Exception:
            # Controller data unavailable: send nothing, and drop the anchor so motion
            # resumes from a fresh reference instead of diffing against a stale pose.
            self._release_clutch()
            return torch.zeros(7)

        # The quaternion is invalid until tracking is established (e.g. before the
        # headset truly connects). Release the clutch so the stale anchor is re-taken.
        if not is_valid_quaternion(raw_pose[3:]):
            self._release_clutch()
            return self._idle_action(trigger)

        # The grip button is the clutch: the arm moves only while grip is held.
        if grip < 0.5:
            self._release_clutch()
            return self._idle_action(trigger)

        if not self.grip_active:
            # First engaged frame: anchor on the raw XR pose (untransformed, to avoid
            # corrupting the quaternion's base frame) and send no motion yet.
            self.prev_xr_pos = raw_pose[:3].copy()
            self.prev_xr_quat = raw_pose[3:].copy()
            self.grip_active = True
            return self._idle_action(trigger)

        # ========== 1. Position Delta (XR frame -> clamp -> world) ==========
        xr_delta_pos = raw_pose[:3] - self.prev_xr_pos
        pos_norm = np.linalg.norm(xr_delta_pos)
        if pos_norm > self.MAX_POS_DELTA:
            xr_delta_pos = xr_delta_pos * (self.MAX_POS_DELTA / pos_norm)
        delta_pos = self._xr_to_world(xr_delta_pos) * self.pos_sensitivity

        # ========== 2. Rotation Delta (XR frame -> clamp -> world) ==========
        # Relative rotation between the anchor and current quaternions (xyzw) as a rotvec.
        xr_delta_rot = quat_diff_as_rotvec_xyzw(self.prev_xr_quat, raw_pose[3:])
        rot_norm = np.linalg.norm(xr_delta_rot)
        if rot_norm > self.MAX_ROT_DELTA:
            xr_delta_rot = xr_delta_rot * (self.MAX_ROT_DELTA / rot_norm)
        delta_rot = self._xr_to_world(xr_delta_rot) * self.rot_sensitivity

        # Advance the anchor so the next frame yields an incremental delta.
        self.prev_xr_pos = raw_pose[:3].copy()
        self.prev_xr_quat = raw_pose[3:].copy()
        self.frame_count += 1

        # ========== 3. Gripper ==========
        gripper_action = self._gripper_cmd(trigger)
        action = torch.tensor(
            [delta_pos[0], delta_pos[1], delta_pos[2],
             delta_rot[0], delta_rot[1], delta_rot[2],
             gripper_action],
            dtype=torch.float32,
        )

        # ========== 4. Debug Log (every 30 active frames) ==========
        if self.frame_count % 30 == 0:
            # Scoped print options: do not alter numpy's global print state.
            with np.printoptions(precision=4, suppress=True, floatmode="fixed"):
                print("\n====================== [VR TELEOP DEBUG] ======================")
                print(f"| Raw VR Pos (OpenXR): {np.array(raw_pose[:3])}")
                print(f"| Raw VR Quat (xyzw): {np.array(raw_pose[3:])}")
                print(f"| XR Delta Pos (raw): {xr_delta_pos} (norm={pos_norm:.4f})")
                print(f"| XR Delta Rot (raw): {xr_delta_rot} (norm={rot_norm:.4f})")
                print("|--------------------------------------------------------------")
                print(f"| Sent Action Pos (dx,dy,dz): {action[:3].numpy()}")
                print(f"| Sent Action Rot (rx,ry,rz): {action[3:6].numpy()}")
                print(f"| Gripper & Trigger: Grip={grip:.2f}, Trig={trigger:.2f}")
                print("===============================================================")
        return action
# =====================================================================
# Main Execution Loop
# =====================================================================
def _left_arm_joint_indices(joint_names) -> list[int]:
    """Return the indices of joints whose names start with ``"l_joint"``.

    NOTE(review): used for the state printout regardless of ``--arm``, so with
    ``--arm right`` the "Left Arm Joints" line still shows the left arm.
    """
    return [i for i, name in enumerate(joint_names) if name.startswith("l_joint")]


def _print_joint_table(joint_names, joint_pos, arm_idx) -> None:
    """Print every joint name with its observed position, then the left-arm indices."""
    print(f"\n{'='*70}")
    print(f" ALL {len(joint_names)} JOINT NAMES AND POSITIONS (relative)")
    print(f"{'='*70}")
    for i, name in enumerate(joint_names):
        print(f" [{i:2d}] {name:30s} = {joint_pos[i]:+.4f}")
    print(f"{'='*70}")
    print(f" Left arm indices: {arm_idx}")
    print(f"{'='*70}\n")


def _print_robot_state(sim_frame, arm_joints, eef_pos, eef_quat, last_act) -> None:
    """Print one robot-state summary block for env 0."""
    print(f"\n---------------- [ROBOT STATE frame={sim_frame}] ----------------")
    print(f"| Left Arm Joints (rad): {arm_joints}")
    print(f"| EEF Pos (world): {eef_pos}")
    print(f"| EEF Quat (world, wxyz): {eef_quat}")
    print(f"| Last Action Sent: {last_act}")
    print("----------------------------------------------------------------")


def _run_teleop_loop(env, teleop_interface) -> None:
    """Step ``env`` with XR actions until the app stops or a step raises.

    Registers a ``"RESET"`` callback on ``teleop_interface`` that schedules an
    environment + teleop reset after the current step. Every 30 steps the env-0 robot
    state from ``obs["policy"]`` is printed; the first such print also dumps all joint
    names. An exception during a step is logged with its traceback and ends the loop.
    """
    should_reset = False

    def request_reset():
        nonlocal should_reset
        should_reset = True
        print("[INFO] Reset requested via XR button.")

    teleop_interface.add_callback("RESET", request_reset)
    env.reset()
    teleop_interface.reset()
    print("\n" + "=" * 50)
    print(" 🚀 Teleoperation Started!")
    print(" 🎮 Use the TRIGGER to open/close gripper.")
    print(" ✊ Hold GRIP button and move the controller to move the arm.")
    print(" 🕹️ Press B or Y to reset the environment.")
    print("=" * 50 + "\n")

    # env is already the unwrapped environment (see main), so its attributes are direct.
    device = env.device
    sim_frame = 0
    arm_idx = None  # resolved once, on the first state printout
    while simulation_app.is_running():
        try:
            with torch.inference_mode():
                # The teleop action has shape (7,); broadcast it to every environment.
                action = teleop_interface.advance()
                actions = action.unsqueeze(0).repeat(env.num_envs, 1).to(device)
                obs, _, _, _, _ = env.step(actions)

                sim_frame += 1
                if sim_frame % 30 == 0:
                    policy_obs = obs["policy"]
                    joint_pos = policy_obs["joint_pos"][0].cpu().numpy()
                    eef_pos = policy_obs["eef_pos"][0].cpu().numpy()
                    eef_quat = policy_obs["eef_quat"][0].cpu().numpy()
                    last_act = policy_obs["actions"][0].cpu().numpy()
                    with np.printoptions(precision=4, suppress=True, floatmode="fixed"):
                        if arm_idx is None:
                            # First printout: dump ALL joint names to identify indices.
                            joint_names = env.scene["robot"].joint_names
                            arm_idx = _left_arm_joint_indices(joint_names)
                            _print_joint_table(joint_names, joint_pos, arm_idx)
                        _print_robot_state(
                            sim_frame, joint_pos[arm_idx], eef_pos, eef_quat, last_act
                        )

                if should_reset:
                    env.reset()
                    teleop_interface.reset()
                    should_reset = False
        except Exception:
            logger.exception("Error during simulation step")
            break


def main() -> None:
    """Run teleoperation with a PICO XR controller against an Isaac Lab environment.

    Builds the environment for ``--task``, connects to the controller selected by
    ``--arm`` and steps the environment with the controller's actions until the
    simulation app stops or a step raises. Pressing B or Y resets the environment.
    The XR client and the environment are closed on every exit path, including
    KeyboardInterrupt. ``simulation_app`` itself is left for the caller to close.

    Raises:
        ValueError: If the task's config is not a ``ManagerBasedRLEnvCfg``.
    """
    # 1. Configuration parsing
    env_cfg = parse_env_cfg(args_cli.task, num_envs=args_cli.num_envs)
    env_cfg.env_name = args_cli.task
    if not isinstance(env_cfg, ManagerBasedRLEnvCfg):
        raise ValueError(f"Teleoperation requires ManagerBasedRLEnvCfg. Got: {type(env_cfg)}")
    # Drop the time-out termination term so the session is not cut off by an episode limit.
    env_cfg.terminations.time_out = None

    # 2. Environment creation
    try:
        env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
    except Exception:
        # Do not close simulation_app here: the caller closes it, and closing it
        # here as well would shut the app down twice.
        logger.exception("Failed to create environment '%s'", args_cli.task)
        return

    try:
        # 3. Teleoperation Interface Initialization
        print(f"\n[INFO] Connecting to PICO XR Headset using {args_cli.arm} controller...")
        teleop_interface = XrTeleopController(
            arm=args_cli.arm,
            pos_sensitivity=args_cli.sensitivity,
            rot_sensitivity=args_cli.sensitivity * 0.3,  # Rotation must be much gentler for DLS IK
        )
        try:
            # 4. Simulation loop
            _run_teleop_loop(env, teleop_interface)
        finally:
            teleop_interface.close()
    finally:
        env.close()
if __name__ == "__main__":
    try:
        main()
    finally:
        # Always shut down the Omniverse app, even if main() raised (e.g. a config
        # ValueError) or was interrupted.
        simulation_app.close()