# Copyright (c) 2025, Yutang Li, SIAT (yt.li2@siat.ac.cn)
# SPDX-License-Identifier: BSD-3-Clause
"""Dual-arm + head tracking + trunk + stereo streaming XR agent."""

from __future__ import annotations

import torch

from .dual_arm_agent import DualArmXrAgent
from .head_tracker import HeadTracker
from .streaming import StreamingManager
from .diagnostics import DiagnosticsReporter


class DualArmHeadXrAgent(DualArmXrAgent):
    """Extends DualArmXrAgent with head tracking, trunk hold, and VR stereo streaming.

    Action: left_arm(7) | wheel(4) | left_grip(1) | right_arm(7) | right_grip(1)
            | head(2) | trunk(1) = 23D

    The first 20 entries come from ``DualArmXrAgent.assemble_action``; this
    class appends the 2 head targets and the constant 1-element trunk command.
    """

    def __init__(self, env, simulation_app, *,
                 pos_sensitivity: float = 1.0,
                 rot_sensitivity: float = 0.3,
                 base_speed: float = 5.0,
                 base_turn: float = 2.0,
                 drive_speed: float = 0.5,
                 drive_turn: float = 1.5,
                 stream_to: str | None = None,
                 stream_port: int = 12345,
                 stream_bitrate: int = 20_000_000,
                 trunk_target: float = 0.1,
                 debug_viewports: bool = True,
                 diagnostics_interval: int = 30):
        """Create the agent.

        Args:
            env: Environment handed to ``DualArmXrAgent``; ``env.unwrapped.scene``
                is also read here when streaming is enabled.
            simulation_app: Simulation application handle, forwarded to the parent.
            pos_sensitivity, rot_sensitivity, base_speed, base_turn,
            drive_speed, drive_turn, debug_viewports: Forwarded unchanged to
                ``DualArmXrAgent.__init__``.
            stream_to: Stream destination. When truthy, a ``StreamingManager``
                is created; when ``None``/empty, streaming is disabled.
            stream_port: Port passed to ``StreamingManager``.
            stream_bitrate: Bitrate passed to ``StreamingManager``.
            trunk_target: Constant value emitted as the trunk command on every step.
            diagnostics_interval: ``interval`` passed to ``DiagnosticsReporter``
                (previously hard-coded to 30).
        """
        super().__init__(
            env, simulation_app,
            pos_sensitivity=pos_sensitivity,
            rot_sensitivity=rot_sensitivity,
            base_speed=base_speed,
            base_turn=base_turn,
            drive_speed=drive_speed,
            drive_turn=drive_turn,
            debug_viewports=debug_viewports,
        )

        self.head_tracker = HeadTracker()
        # Held constant for the whole session; shape (1,) so it can be
        # concatenated directly onto the action vector.
        self.trunk_cmd = torch.tensor([trunk_target], dtype=torch.float32)

        # Streaming (optional)
        self.streamer: StreamingManager | None = None
        if stream_to:
            scene = env.unwrapped.scene
            self.streamer = StreamingManager(
                stream_to, stream_port, scene, bitrate=stream_bitrate,
            )

        # Diagnostics
        self.diagnostics = DiagnosticsReporter(
            interval=diagnostics_interval, is_dual_arm=True,
        )

    def on_reset(self):
        """Reset the parent agent state, then the head tracker."""
        super().on_reset()
        self.head_tracker.reset()

    def assemble_action(self, obs) -> torch.Tensor:
        """Build the full 23-D action: parent action, then head (2), then trunk (1).

        The head and trunk tensors are placed on ``base_action``'s device so the
        concatenation also works when the parent produces a non-CPU tensor.
        """
        base_action = super().assemble_action(obs)
        device = base_action.device

        # Head tracking. Assumes get_targets returns 2 values (head(2) in the
        # action layout) — TODO confirm against HeadTracker.
        head_targets = self.head_tracker.get_targets(self.shared_client)
        head_cmd = torch.tensor(head_targets, dtype=torch.float32, device=device)

        # .to() returns the same tensor (no copy) when already on `device`.
        trunk_cmd = self.trunk_cmd.to(device)

        # base(20) | head(2) | trunk(1)
        return torch.cat([base_action, head_cmd, trunk_cmd])

    def post_step(self, obs):
        """Run parent post-step hooks, push a stereo frame, and report diagnostics."""
        super().post_step(obs)

        # Stereo streaming (only when enabled and not yet cleaned up)
        if self.streamer is not None:
            scene = self.env.unwrapped.scene
            self.streamer.send(scene, self.sim_frame)

        # Diagnostics; presumably DiagnosticsReporter throttles output to its
        # `interval` — verify in diagnostics.py.
        self.diagnostics.report(
            self.env, obs, self.sim_frame,
            xr_client=self.shared_client,
        )

    def cleanup(self):
        """Close the stereo stream (if any), then run the parent cleanup.

        The streamer reference is cleared before closing, so calling
        ``cleanup`` again does not touch the closed streamer. A later
        ``post_step`` will not try to stream either. The parent cleanup runs
        in ``finally`` so an exception from ``close()`` cannot skip it.
        """
        streamer, self.streamer = self.streamer, None
        try:
            if streamer is not None:
                streamer.close()
        finally:
            super().cleanup()