diff --git a/lerobot/scripts/server/robot_client.py b/lerobot/scripts/server/robot_client.py index 398fe8d4..165c7ec2 100644 --- a/lerobot/scripts/server/robot_client.py +++ b/lerobot/scripts/server/robot_client.py @@ -1,12 +1,12 @@ +import pickle # nosec import threading import time from queue import Empty, Queue -from typing import Optional +from typing import Any, Optional import async_inference_pb2 # type: ignore import async_inference_pb2_grpc # type: ignore import grpc -import numpy as np import torch from lerobot.common.robot_devices.robots.utils import make_robot @@ -16,28 +16,40 @@ idle_wait = 0.1 class TimedData: - def __init__(self, timestep: int, data: np.ndarray): - self.timestep = timestep + def __init__(self, timestamp: float, data: Any, timestep: int): + """Initialize a TimedData object. + + Args: + timestamp: Unix timestamp relative to data's creation. + data: The actual data to wrap a timestamp around. + """ + self.timestamp = timestamp self.data = data + self.timestep = timestep def get_data(self): return self.data + def get_timestamp(self): + return self.timestamp + def get_timestep(self): return self.timestep class TimedAction(TimedData): - def __init__(self, timestep: int, action: np.ndarray): - super().__init__(timestep, action) + def __init__(self, timestamp: float, action: torch.Tensor, timestep: int): + super().__init__(timestamp=timestamp, data=action, timestep=timestep) def get_action(self): return self.get_data() class TimedObservation(TimedData): - def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0): - super().__init__(timestep, observation) + def __init__( + self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0 + ): + super().__init__(timestamp=timestamp, data=observation, timestep=timestep) self.transfer_state = transfer_state def get_observation(self): @@ -60,9 +72,10 @@ class RobotClient: self.action_chunk_size = 20 self.action_queue = Queue() - self.start_barrier = threading.Barrier(3) # Barrier for 3 threads + self.start_barrier = threading.Barrier(3) - self.observation_timestep = 0 + # Create a lock for robot access + self.robot_lock = threading.Lock() self.use_robot = use_robot if self.use_robot: @@ -72,8 +85,10 @@ class RobotClient: time.sleep(idle_wait) # sleep waiting for cameras to activate print("Robot connected") - def timesteps(self): - """Get the timesteps of the actions in the queue""" + self.robot_reading = True + + def timestamps(self): + """Get the timestamps of the actions in the queue""" return sorted([action.get_timestep() for action in self.action_queue.queue]) def start(self): @@ -93,11 +108,13 @@ class RobotClient: def stop(self): """Stop the robot client""" self.running = False + if self.use_robot and hasattr(self, "robot"): + self.robot.disconnect() self.channel.close() def send_observation( self, - observation_data: np.ndarray, + obs: TimedObservation, transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE, ) -> bool: """Send observation to the policy server. @@ -106,10 +123,10 @@ class RobotClient: print("Client not running") return False - # Convert observation data to bytes - observation_data = observation_data.tobytes() + assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!" - observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data) + observation_bytes = pickle.dumps(obs) + observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes) try: _ = self.stub.SendObservations(iter([observation])) @@ -121,24 +138,19 @@ class RobotClient: print(f"Error sending observation: {e}") return False - def _validate_action(self, action: np.ndarray): - """Validate the action""" - assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}" + def _validate_action(self, action: TimedAction): + """Received actions are keps only when they have been produced for now or later, never before""" + return not action.get_timestamp() < self.latest_action - return True - - def _validate_action_chunk(self, actions: list[np.ndarray]): - """Validate the action chunk""" + def _validate_action_chunk(self, actions: list[TimedAction]): assert len(actions) == self.action_chunk_size, ( f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}" ) - assert all(self._validate_action(action) for action in actions), "Invalid action in chunk" return True def _inspect_action_queue(self): - """Inspect the action queue""" print("Queue size: ", self.action_queue.qsize()) print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue])) @@ -151,13 +163,10 @@ class RobotClient: break def _fill_action_queue(self, actions: list[TimedAction]): - """Fill the action queue with incoming actions""" + """Fill the action queue with incoming valid actions""" for action in actions: - # Only keep the actions that are newer than the latest action - if action.get_timestep() <= self.latest_action: - continue - - self.action_queue.put(action) + if self._validate_action(action): + self.action_queue.put(action) def _update_action_queue(self, actions: list[TimedAction]): """Aggregate incoming actions into the action queue. @@ -176,6 +185,7 @@ class RobotClient: Args: actions: List of TimedAction instances to queue """ + print("*** Current latest action: ", self.latest_action, "***") print("\t**** Current queue content ****: ") self._inspect_action_queue() @@ -188,55 +198,22 @@ class RobotClient: print("\t*** Queue after clearing and filling ****: ") self._inspect_action_queue() - def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]: - """Create TimedAction instances from raw action data. - - Args: - action_data: Numpy array of shape (chunk_size, 7) where first column - is timestep and remaining columns are action values. - - Returns: - List of TimedAction instances - """ - timed_actions = [] - for action in action_data: - timestep = int(action[0]) # First element is the timestep - action_values = action[1:] # Remaining elements are the action - timed_actions.append(TimedAction(timestep, action_values)) - - return timed_actions - def receive_actions(self): """Receive actions from the policy server""" # Wait at barrier for synchronized start self.start_barrier.wait() print("Action receiving thread starting") - print(self.timesteps()) - while self.running: try: # Use StreamActions to get a stream of actions from the server - action_chunks_counter = 0 - for action in self.stub.StreamActions(async_inference_pb2.Empty()): - # Read the action data which includes timesteps - # Shape is (chunk_size, 7) where first column is timestep - action_data = np.frombuffer(action.data, dtype=np.float32).reshape( - self.action_chunk_size, 7 - ) - - print("*** Receiving actions ****: ") - # Convert raw action data to TimedAction instances - timed_actions = self._create_timed_actions(action_data) + for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()): + # Deserialize bytes back into list[TimedAction] + timed_actions = pickle.loads(actions_chunk.data) # nosec # strategy for queue composition is specified in the method self._clear_and_fill_action_queue(timed_actions) - action_chunks_counter += 1 - - if action_chunks_counter > 2: - raise ValueError("Too many action chunks received") - except grpc.RpcError as e: print(f"Error receiving actions: {e}") time.sleep(idle_wait) # Avoid tight loop on error @@ -258,16 +235,24 @@ class RobotClient: while self.running: # Get the next action from the queue + time.sleep(environment_dt) timed_action = self._get_next_action() if timed_action is not None: - self.latest_action = timed_action.get_timestep() + # self.latest_action = timed_action.get_timestep() + self.latest_action = timed_action.get_timestamp() # Convert action to tensor and send to robot if self.use_robot: - self.robot.send_action(torch.tensor(timed_action.get_action())) - - time.sleep(environment_dt) + # Acquire lock before accessing the robot + if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock + try: + self.robot.send_action(timed_action.get_action()) + finally: + # Always release the lock in a finally block to ensure it's released + self.robot_lock.release() + else: + print("Could not acquire robot lock for action execution, retrying next cycle") else: # No action available, wait and retry fetching from queue @@ -282,8 +267,14 @@ class RobotClient: first_observation = True while self.running: try: + # Get serialized observation bytes from the function + time.sleep(environment_dt) observation = get_observation_fn() + # Skip if observation is None (couldn't acquire lock) + if observation is None: + continue + # Set appropriate transfer state if first_observation: state = async_inference_pb2.TRANSFER_BEGIN @@ -291,12 +282,6 @@ class RobotClient: else: state = async_inference_pb2.TRANSFER_MIDDLE - # Build timestep element in observation - # observation = np.hstack( - # (np.array([self.latest_action]), observation) - # ).astype(np.float32) - - time.sleep(environment_dt) # Control the observation sending rate self.send_observation(observation, state) except Exception as e: @@ -315,8 +300,21 @@ def async_client(): if not hasattr(get_observation, "counter"): get_observation.counter = 0 - # Create observation with incrementing first element - observation = np.array([get_observation.counter, 0, 0], dtype=np.float32) + # Acquire lock before accessing the robot + observation_content = None + if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock + try: + observation_content = client.robot.capture_observation() + finally: + # Always release the lock in a finally block to ensure it's released + client.robot_lock.release() + else: + print("Could not acquire robot lock for observation capture, skipping this cycle") + return None # Return None to indicate no observation was captured + + observation = TimedObservation( + timestamp=time.time(), observation=observation_content, timestep=get_observation.counter + ) # Increment counter for next call get_observation.counter += 1