Add timeout/event logic

This commit is contained in:
Simon Alibert
2025-05-23 16:54:41 +02:00
parent 325f5d72f8
commit b623a15a16
2 changed files with 70 additions and 36 deletions

View File

@@ -16,7 +16,6 @@
import logging
import threading
import time
from pprint import pformat
import serial
@@ -53,8 +52,11 @@ class HomonculusArm(Teleoperator):
"wrist_yaw": MotorNormMode.RANGE_M100_100,
"wrist_pitch": MotorNormMode.RANGE_M100_100,
}
self.thread = threading.Thread(target=self._async_read, daemon=True, name=f"{self} _async_read")
self._lock = threading.Lock()
self._state: dict[str, float] | None = None
self.new_state_event = threading.Event()
self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
self.lock = threading.Lock()
@property
def action_features(self) -> dict:
@@ -75,7 +77,11 @@ class HomonculusArm(Teleoperator):
if not self.serial.is_open:
self.serial.open()
self.thread.start()
time.sleep(1) # gives time for the thread to ramp up
# wait for the thread to ramp up & 1st state to be ready
if not self.new_state_event.wait(timeout=2):
raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
logger.info(f"{self} connected.")
@property
@@ -176,14 +182,24 @@ class HomonculusArm(Teleoperator):
return normalized_values
def _read(self, joints: list[str] | None = None, normalize: bool = True) -> dict[str, int | float]:
def _read(
self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
) -> dict[str, int | float]:
"""
Return the most recent (single) values from self.last_d,
optionally applying calibration.
"""
with self._lock:
if not self.new_state_event.wait(timeout=timeout):
raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
with self.lock:
state = self._state
self.new_state_event.clear()
if state is None:
raise RuntimeError(f"{self} Internal error: Event set but no state available.")
if joints is not None:
state = {k: v for k, v in state.items() if k in joints}
@@ -192,19 +208,19 @@ class HomonculusArm(Teleoperator):
return state
def _async_read(self):
def _read_loop(self):
"""
Continuously read from the serial buffer in its own thread and sends values to the main thread through
a queue.
"""
while True:
if self.serial.in_waiting > 0:
self.serial.flush()
raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
if len(raw_values) != 21: # 16 raw + 5 angle values
continue
try:
if self.serial.in_waiting > 0:
self.serial.flush()
raw_values = self.serial.readline().decode("utf-8").strip().split(" ")
if len(raw_values) != 21: # 16 raw + 5 angle values
continue
try:
joint_angles = {
"shoulder_pitch": int(raw_values[19]),
"shoulder_yaw": int(raw_values[18]),
@@ -214,11 +230,13 @@ class HomonculusArm(Teleoperator):
"wrist_yaw": int(raw_values[1]),
"wrist_pitch": int(raw_values[0]),
}
except ValueError:
continue
with self._lock:
self._state = joint_angles
with self._lock:
self._state = joint_angles
self.new_state_event.set()
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def get_action(self) -> dict[str, float]:
joint_positions = self._read()

View File

@@ -16,7 +16,6 @@
import logging
import threading
import time
from pprint import pformat
import serial
@@ -63,7 +62,6 @@ class HomonculusGlove(Teleoperator):
"pinky_mcp_flexion": MotorNormMode.RANGE_0_100,
"pinky_dip": MotorNormMode.RANGE_0_100,
}
self.inverted_joints = [
"thumb_cmc",
"index_dip",
@@ -72,9 +70,11 @@ class HomonculusGlove(Teleoperator):
"pinky_mcp_abduction",
"pinky_dip",
]
# self._state = dict.fromkeys(self.joints, 100)
self.thread = threading.Thread(target=self._async_read, daemon=True, name=f"{self} _async_read")
self._lock = threading.Lock()
self._state: dict[str, float] | None = None
self.new_state_event = threading.Event()
self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop")
self.lock = threading.Lock()
@property
def action_features(self) -> dict:
@@ -95,7 +95,11 @@ class HomonculusGlove(Teleoperator):
if not self.serial.is_open:
self.serial.open()
self.thread.start()
time.sleep(0.5) # gives time for the thread to ramp up
# wait for the thread to ramp up & 1st state to be ready
if not self.new_state_event.wait(timeout=2):
raise TimeoutError(f"{self}: Timed out waiting for state after 2s.")
logger.info(f"{self} connected.")
@property
@@ -201,14 +205,24 @@ class HomonculusGlove(Teleoperator):
return normalized_values
def _read(self, joints: list[str] | None = None, normalize: bool = True) -> dict[str, int | float]:
def _read(
self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1
) -> dict[str, int | float]:
"""
Return the most recent (single) values from self.last_d,
optionally applying calibration.
"""
with self._lock:
if not self.new_state_event.wait(timeout=timeout):
raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.")
with self.lock:
state = self._state
self.new_state_event.clear()
if state is None:
raise RuntimeError(f"{self} Internal error: Event set but no state available.")
if joints is not None:
state = {k: v for k, v in state.items() if k in joints}
@@ -217,27 +231,29 @@ class HomonculusGlove(Teleoperator):
return state
def _async_read(self):
def _read_loop(self):
"""
Continuously read from the serial buffer in its own thread and sends values to the main thread through
a queue.
"""
while True:
if self.serial.in_waiting > 0:
self.serial.flush()
positions = self.serial.readline().decode("utf-8").strip().split(" ")
if len(positions) != len(self.joints):
continue
try:
if self.serial.in_waiting > 0:
self.serial.flush()
positions = self.serial.readline().decode("utf-8").strip().split(" ")
if len(positions) != len(self.joints):
continue
try:
joint_positions = {
joint: int(pos) for joint, pos in zip(self.joints, positions, strict=True)
}
except ValueError:
continue
with self._lock:
self._state = joint_positions
with self.lock:
self._state = joint_positions
self.new_state_event.set()
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def get_action(self) -> dict[str, float]:
joint_positions = self._read()