fix: client sends timed objects only, and uses lock to read & write robot status

This commit is contained in:
Francesco Capuano
2025-04-19 14:30:29 +02:00
parent 200ba1feb5
commit b2d003e6eb

View File

@@ -1,12 +1,12 @@
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Optional
from typing import Any, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import numpy as np
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
@@ -16,28 +16,40 @@ idle_wait = 0.1
class TimedData:
    """Wrap a payload together with its creation timestamp and its timestep."""

    def __init__(self, timestamp: float, data: Any, timestep: int):
        """Initialize a TimedData object.

        Args:
            timestamp: Unix timestamp relative to data's creation.
            data: The actual data to wrap a timestamp around.
            timestep: Integer step counter attached to the data.
        """
        self.timestamp = timestamp
        self.data = data
        self.timestep = timestep

    def get_data(self):
        """Return the wrapped payload."""
        return self.data

    def get_timestamp(self):
        """Return the timestamp given at construction."""
        return self.timestamp

    def get_timestep(self):
        """Return the timestep given at construction."""
        return self.timestep
class TimedAction(TimedData):
    """TimedData whose payload is an action."""

    def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
        """Initialize a TimedAction.

        Args:
            timestamp: Timestamp forwarded to TimedData.
            action: The action payload, stored as the wrapped data.
            timestep: Timestep forwarded to TimedData.
        """
        # Keyword arguments keep this call correct regardless of the parent's
        # parameter order.
        super().__init__(timestamp=timestamp, data=action, timestep=timestep)

    def get_action(self):
        """Return the wrapped action (same object as get_data())."""
        return self.get_data()
class TimedObservation(TimedData):
def __init__(
    self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
):
    """Initialize a TimedObservation.

    Args:
        timestamp: Timestamp forwarded to TimedData.
        observation: The observation payload, stored as the wrapped data.
        timestep: Timestep forwarded to TimedData.
        transfer_state: Value stored on the instance as ``transfer_state``;
            defaults to 0.
    """
    # Keyword arguments keep this call correct regardless of the parent's
    # parameter order.
    super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
    self.transfer_state = transfer_state
def get_observation(self):
@@ -60,9 +72,10 @@ class RobotClient:
self.action_chunk_size = 20
self.action_queue = Queue()
self.start_barrier = threading.Barrier(3) # Barrier for 3 threads
self.start_barrier = threading.Barrier(3)
self.observation_timestep = 0
# Create a lock for robot access
self.robot_lock = threading.Lock()
self.use_robot = use_robot
if self.use_robot:
@@ -72,8 +85,10 @@ class RobotClient:
time.sleep(idle_wait) # sleep waiting for cameras to activate
print("Robot connected")
def timesteps(self):
"""Get the timesteps of the actions in the queue"""
self.robot_reading = True
def timestamps(self):
    """Get the timestamps of the actions in the queue.

    Returns:
        The timestamps of every queued action, sorted in ascending order.
        The queue itself is only read, never modified.
    """
    # The method name and docstring promise timestamps, so read
    # get_timestamp() rather than get_timestep().
    return sorted(action.get_timestamp() for action in self.action_queue.queue)
def start(self):
@@ -93,11 +108,13 @@ class RobotClient:
def stop(self):
    """Stop the robot client.

    Clears the ``running`` flag, disconnects the robot when ``use_robot`` is
    set and a ``robot`` attribute exists, and closes the channel.

    Raises:
        Whatever ``robot.disconnect()`` or ``channel.close()`` raises; the
        channel is closed even when disconnecting the robot fails.
    """
    self.running = False
    try:
        if self.use_robot and hasattr(self, "robot"):
            self.robot.disconnect()
    finally:
        # Release the channel even if the robot could not be disconnected.
        self.channel.close()
def send_observation(
self,
observation_data: np.ndarray,
obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
"""Send observation to the policy server.
@@ -106,10 +123,10 @@ class RobotClient:
print("Client not running")
return False
# Convert observation data to bytes
observation_data = observation_data.tobytes()
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)
observation_bytes = pickle.dumps(obs)
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try:
_ = self.stub.SendObservations(iter([observation]))
@@ -121,24 +138,19 @@ class RobotClient:
print(f"Error sending observation: {e}")
return False
def _validate_action(self, action: "TimedAction"):
    """Received actions are kept only when they have been produced for now or later, never before.

    Args:
        action: The timed action to check.

    Returns:
        True when the action's timestamp is not earlier than
        ``self.latest_action``, False otherwise.
    """
    return not (action.get_timestamp() < self.latest_action)
def _validate_action_chunk(self, actions: "list[TimedAction]"):
    """Validate an incoming chunk of timed actions.

    Args:
        actions: The chunk to check.

    Returns:
        True when the chunk passes both checks.

    Raises:
        AssertionError: If the chunk length differs from
            ``self.action_chunk_size`` or any action fails
            ``self._validate_action``.
    """
    assert len(actions) == self.action_chunk_size, (
        f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
    )
    assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"

    return True
def _inspect_action_queue(self):
    """Print the action queue's size, then the sorted timesteps of its entries."""
    pending = self.action_queue
    print("Queue size: ", pending.qsize())
    queued_steps = [entry.get_timestep() for entry in pending.queue]
    queued_steps.sort()
    print("Queue contents: ", queued_steps)
@@ -151,13 +163,10 @@ class RobotClient:
break
def _fill_action_queue(self, actions: "list[TimedAction]"):
    """Fill the action queue with incoming valid actions.

    Args:
        actions: Candidate actions; each one is put on ``self.action_queue``
            exactly once, and only if ``self._validate_action`` accepts it.
    """
    for action in actions:
        # Validity is decided in one place (_validate_action) so the filter
        # here cannot drift from it.
        if self._validate_action(action):
            self.action_queue.put(action)
def _update_action_queue(self, actions: list[TimedAction]):
"""Aggregate incoming actions into the action queue.
@@ -176,6 +185,7 @@ class RobotClient:
Args:
actions: List of TimedAction instances to queue
"""
print("*** Current latest action: ", self.latest_action, "***")
print("\t**** Current queue content ****: ")
self._inspect_action_queue()
@@ -188,55 +198,22 @@ class RobotClient:
print("\t*** Queue after clearing and filling ****: ")
self._inspect_action_queue()
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
"""Create TimedAction instances from raw action data.
Args:
action_data: Numpy array of shape (chunk_size, 7) where first column
is timestep and remaining columns are action values.
Returns:
List of TimedAction instances
"""
timed_actions = []
# Each row is split into its leading timestep and the remaining action values.
for action in action_data:
timestep = int(action[0]) # First element is the timestep
action_values = action[1:] # Remaining elements are the action
# NOTE(review): this passes (timestep, values) positionally, while this file also
# shows a TimedAction.__init__(timestamp, action, timestep) signature — confirm
# which constructor this call targets before relying on this helper.
timed_actions.append(TimedAction(timestep, action_values))
return timed_actions
def receive_actions(self):
    """Receive actions from the policy server.

    Waits on the start barrier, then, for as long as ``self.running`` is
    true, iterates the ``StreamActions`` response stream. Each message's
    ``data`` payload is unpickled and passed to
    ``_clear_and_fill_action_queue``. A ``grpc.RpcError`` is printed and the
    stream is re-opened after sleeping ``idle_wait`` seconds.
    """
    # Wait at barrier for synchronized start
    self.start_barrier.wait()
    print("Action receiving thread starting")

    while self.running:
        try:
            # Use StreamActions to get a stream of actions from the server
            for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
                # Deserialize bytes back into list[TimedAction]
                # SECURITY(review): pickle.loads can execute arbitrary code
                # embedded in the payload; this is only acceptable when the
                # policy server and the transport are trusted.
                timed_actions = pickle.loads(actions_chunk.data)  # nosec

                # strategy for queue composition is specified in the method
                self._clear_and_fill_action_queue(timed_actions)

        except grpc.RpcError as e:
            print(f"Error receiving actions: {e}")
            time.sleep(idle_wait)  # Avoid tight loop on error
@@ -258,16 +235,24 @@ class RobotClient:
while self.running:
# Get the next action from the queue
time.sleep(environment_dt)
timed_action = self._get_next_action()
if timed_action is not None:
self.latest_action = timed_action.get_timestep()
# self.latest_action = timed_action.get_timestep()
self.latest_action = timed_action.get_timestamp()
# Convert action to tensor and send to robot
if self.use_robot:
self.robot.send_action(torch.tensor(timed_action.get_action()))
time.sleep(environment_dt)
# Acquire lock before accessing the robot
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
self.robot.send_action(timed_action.get_action())
finally:
# Always release the lock in a finally block to ensure it's released
self.robot_lock.release()
else:
print("Could not acquire robot lock for action execution, retrying next cycle")
else:
# No action available, wait and retry fetching from queue
@@ -282,8 +267,14 @@ class RobotClient:
first_observation = True
while self.running:
try:
# Get serialized observation bytes from the function
time.sleep(environment_dt)
observation = get_observation_fn()
# Skip if observation is None (couldn't acquire lock)
if observation is None:
continue
# Set appropriate transfer state
if first_observation:
state = async_inference_pb2.TRANSFER_BEGIN
@@ -291,12 +282,6 @@ class RobotClient:
else:
state = async_inference_pb2.TRANSFER_MIDDLE
# Build timestep element in observation
# observation = np.hstack(
# (np.array([self.latest_action]), observation)
# ).astype(np.float32)
time.sleep(environment_dt) # Control the observation sending rate
self.send_observation(observation, state)
except Exception as e:
@@ -315,8 +300,21 @@ def async_client():
if not hasattr(get_observation, "counter"):
get_observation.counter = 0
# Create observation with incrementing first element
observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
# Acquire lock before accessing the robot
observation_content = None
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
observation_content = client.robot.capture_observation()
finally:
# Always release the lock in a finally block to ensure it's released
client.robot_lock.release()
else:
print("Could not acquire robot lock for observation capture, skipping this cycle")
return None # Return None to indicate no observation was captured
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
)
# Increment counter for next call
get_observation.counter += 1