116 lines
4.3 KiB
Python
116 lines
4.3 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2025, The Isaac Lab Project Developers.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
"""Dual-arm XR teleoperation agent with chassis control."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
from xr_utils import XrClient
|
|
|
|
from .base_agent import BaseTeleopAgent
|
|
from .xr_controller import XrTeleopController
|
|
from .chassis import ChassisController
|
|
from .frame_utils import convert_action_world_to_root
|
|
|
|
|
|
class DualArmXrAgent(BaseTeleopAgent):
    """Dual-arm teleoperation with chassis joystick control.

    Both arm controllers share a single ``XrClient``; the same client is also
    passed to ``ChassisController`` to read the chassis commands.

    Action: left_arm(7) | wheel(4) | left_grip(1) | right_arm(7) | right_grip(1) = 20D
    """

    def __init__(self, env, simulation_app, *,
                 pos_sensitivity: float = 1.0,
                 rot_sensitivity: float = 0.3,
                 base_speed: float = 5.0,
                 base_turn: float = 2.0,
                 drive_speed: float = 0.5,
                 drive_turn: float = 1.5,
                 debug_viewports: bool = True):
        super().__init__(env, simulation_app, debug_viewports=debug_viewports)

        # One XR connection shared by both arm controllers and the chassis.
        self.shared_client = XrClient()
        self.teleop_left = XrTeleopController(
            arm="left", pos_sensitivity=pos_sensitivity,
            rot_sensitivity=rot_sensitivity, xr_client=self.shared_client,
        )
        self.teleop_right = XrTeleopController(
            arm="right", pos_sensitivity=pos_sensitivity,
            rot_sensitivity=rot_sensitivity, xr_client=self.shared_client,
        )
        # The left controller's "RESET" event requests an environment reset.
        self.teleop_left.add_callback("RESET", self.request_reset)

        self.chassis = ChassisController(
            base_speed=base_speed, base_turn=base_turn,
            drive_speed=drive_speed, drive_turn=drive_turn,
        )

        # Last root-frame arm targets; reused while the grip is released.
        self._last_root_left = None
        self._last_root_right = None

        # Chassis rates produced by assemble_action() and consumed by
        # post_step(). None until the first action has been assembled, so
        # post_step() can detect that no command exists yet.
        self._v_fwd = None
        self._omega = None

    @property
    def xr_client(self):
        """The ``XrClient`` shared by both arm controllers and the chassis."""
        return self.shared_client

    def _ik_action_term_names(self) -> list[str]:
        """Return the names of the arm action terms (left, then right)."""
        return ["left_arm_action", "right_arm_action"]

    def on_reset(self):
        """Reset both arm controllers and drop the held root-frame targets.

        Clearing the held targets forces assemble_action() to recompute them
        on the next step.
        """
        self.teleop_left.reset()
        self.teleop_right.reset()
        self._last_root_left = None
        self._last_root_right = None

    @staticmethod
    def _advance_arm(teleop, policy_obs, side: str):
        """Advance one arm's teleop controller from the current end-effector pose.

        Reads ``eef_pos_<side>`` / ``eef_quat_<side>`` at index 0 (presumably
        env 0 -- confirm) and returns whatever ``teleop.advance`` returns.
        """
        eef_pos = policy_obs[f"eef_pos_{side}"][0].cpu().numpy()
        eef_quat = policy_obs[f"eef_quat_{side}"][0].cpu().numpy()
        return teleop.advance(current_eef_pos=eef_pos, current_eef_quat=eef_quat)

    def assemble_action(self, obs) -> torch.Tensor:
        """Build this step's action from the XR controllers and chassis joystick.

        Args:
            obs: Observation dict whose ``"policy"`` entry provides
                ``eef_pos_left``, ``eef_quat_left``, ``eef_pos_right`` and
                ``eef_quat_right``.

        Returns:
            The tensors concatenated along dim 0 in the order
            left_arm(7) | wheel(4) | left_grip(1) | right_arm(7) | right_grip(1).

        Side effects:
            Stores the chassis forward/turn rates for post_step() and updates the
            held root-frame arm targets.
        """
        policy_obs = obs["policy"]
        robot = self.env.unwrapped.scene["robot"]

        # Chassis: wheel command goes into the action; the forward/turn rates
        # are applied later in post_step() as a direct base-velocity override.
        wheel_cmd, self._v_fwd, self._omega = self.chassis.get_commands(self.shared_client)

        left_action = self._advance_arm(self.teleop_left, policy_obs, "left")
        right_action = self._advance_arm(self.teleop_right, policy_obs, "right")

        # Joint-locking: re-express an arm target in the root frame only while
        # its grip is held (or when no target is held yet); otherwise keep the
        # previous root-frame target.
        if self.teleop_left.grip_active or self._last_root_left is None:
            self._last_root_left = convert_action_world_to_root(left_action, robot)[:7].clone()
        if self.teleop_right.grip_active or self._last_root_right is None:
            self._last_root_right = convert_action_world_to_root(right_action, robot)[:7].clone()

        # left_arm(7) | wheel(4) | left_grip(1) | right_arm(7) | right_grip(1)
        return torch.cat([
            self._last_root_left, wheel_cmd, left_action[7:8],
            self._last_root_right, right_action[7:8],
        ])

    def post_step(self, obs):
        """Apply the direct skid-steer base-velocity override after a step.

        Does nothing until assemble_action() has produced a chassis command.
        """
        if self._v_fwd is None or self._omega is None:
            return
        robot = self.env.unwrapped.scene["robot"]
        self.chassis.apply_base_velocity(robot, self._v_fwd, self._omega, self.num_envs, self.device)

    def cleanup(self):
        """Close both arm controllers and the shared XR client, then run base cleanup.

        Every step runs even if an earlier one raises, so the XR connection
        and base-class resources are always released. The first error
        encountered still propagates, with any later errors chained to it.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            # Callbacks run in reverse registration order: left, right,
            # shared client, then the base class.
            stack.callback(super().cleanup)
            stack.callback(self.shared_client.close)
            stack.callback(self.teleop_right.close)
            stack.callback(self.teleop_left.close)

    def _print_banner(self):
        """Print the control-mapping help banner to stdout."""
        print("\n" + "=" * 50)
        print(" Teleoperation Started!")
        print(" LEFT controller -> left arm")
        print(" RIGHT controller -> right arm")
        print(" TRIGGER: open/close gripper")
        print(" GRIP (hold): move the arm")
        print(" Left joystick: Y=forward/back, X=turn")
        print(" Y (left controller): reset environment")
        print("=" * 50 + "\n")
|