fix: client sends timed objects only, and uses lock to read & write robot status

This commit is contained in:
Francesco Capuano
2025-04-19 14:30:29 +02:00
parent 200ba1feb5
commit b2d003e6eb

View File

@@ -1,12 +1,12 @@
import pickle # nosec
import threading import threading
import time import time
from queue import Empty, Queue from queue import Empty, Queue
from typing import Optional from typing import Any, Optional
import async_inference_pb2 # type: ignore import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore import async_inference_pb2_grpc # type: ignore
import grpc import grpc
import numpy as np
import torch import torch
from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.robot_devices.robots.utils import make_robot
@@ -16,28 +16,40 @@ idle_wait = 0.1
class TimedData: class TimedData:
def __init__(self, timestep: int, data: np.ndarray): def __init__(self, timestamp: float, data: Any, timestep: int):
self.timestep = timestep """Initialize a TimedData object.
Args:
timestamp: Unix timestamp marking when the data was created.
data: The actual data to wrap a timestamp around.
timestep: Integer step index associated with the data.
"""
self.timestamp = timestamp
self.data = data self.data = data
self.timestep = timestep
def get_data(self): def get_data(self):
return self.data return self.data
def get_timestamp(self):
return self.timestamp
def get_timestep(self): def get_timestep(self):
return self.timestep return self.timestep
class TimedAction(TimedData): class TimedAction(TimedData):
def __init__(self, timestep: int, action: np.ndarray): def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
super().__init__(timestep, action) super().__init__(timestamp=timestamp, data=action, timestep=timestep)
def get_action(self): def get_action(self):
return self.get_data() return self.get_data()
class TimedObservation(TimedData): class TimedObservation(TimedData):
def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0): def __init__(
super().__init__(timestep, observation) self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
):
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
self.transfer_state = transfer_state self.transfer_state = transfer_state
def get_observation(self): def get_observation(self):
@@ -60,9 +72,10 @@ class RobotClient:
self.action_chunk_size = 20 self.action_chunk_size = 20
self.action_queue = Queue() self.action_queue = Queue()
self.start_barrier = threading.Barrier(3) # Barrier for 3 threads self.start_barrier = threading.Barrier(3)
self.observation_timestep = 0 # Create a lock for robot access
self.robot_lock = threading.Lock()
self.use_robot = use_robot self.use_robot = use_robot
if self.use_robot: if self.use_robot:
@@ -72,8 +85,10 @@ class RobotClient:
time.sleep(idle_wait) # sleep waiting for cameras to activate time.sleep(idle_wait) # sleep waiting for cameras to activate
print("Robot connected") print("Robot connected")
def timesteps(self): self.robot_reading = True
"""Get the timesteps of the actions in the queue"""
def timestamps(self):
"""Get the sorted timesteps of the actions in the queue (despite the method name, this reads get_timestep())"""
return sorted([action.get_timestep() for action in self.action_queue.queue]) return sorted([action.get_timestep() for action in self.action_queue.queue])
def start(self): def start(self):
@@ -93,11 +108,13 @@ class RobotClient:
def stop(self): def stop(self):
"""Stop the robot client""" """Stop the robot client"""
self.running = False self.running = False
if self.use_robot and hasattr(self, "robot"):
self.robot.disconnect()
self.channel.close() self.channel.close()
def send_observation( def send_observation(
self, self,
observation_data: np.ndarray, obs: TimedObservation,
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE, transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool: ) -> bool:
"""Send observation to the policy server. """Send observation to the policy server.
@@ -106,10 +123,10 @@ class RobotClient:
print("Client not running") print("Client not running")
return False return False
# Convert observation data to bytes assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
observation_data = observation_data.tobytes()
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data) observation_bytes = pickle.dumps(obs)
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
try: try:
_ = self.stub.SendObservations(iter([observation])) _ = self.stub.SendObservations(iter([observation]))
@@ -121,24 +138,19 @@ class RobotClient:
print(f"Error sending observation: {e}") print(f"Error sending observation: {e}")
return False return False
def _validate_action(self, action: np.ndarray): def _validate_action(self, action: TimedAction):
"""Validate the action""" """Received actions are kept only when they have been produced for now or later, never before"""
assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}" return not action.get_timestamp() < self.latest_action
return True def _validate_action_chunk(self, actions: list[TimedAction]):
def _validate_action_chunk(self, actions: list[np.ndarray]):
"""Validate the action chunk"""
assert len(actions) == self.action_chunk_size, ( assert len(actions) == self.action_chunk_size, (
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}" f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
) )
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk" assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
return True return True
def _inspect_action_queue(self): def _inspect_action_queue(self):
"""Inspect the action queue"""
print("Queue size: ", self.action_queue.qsize()) print("Queue size: ", self.action_queue.qsize())
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue])) print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
@@ -151,13 +163,10 @@ class RobotClient:
break break
def _fill_action_queue(self, actions: list[TimedAction]): def _fill_action_queue(self, actions: list[TimedAction]):
"""Fill the action queue with incoming actions""" """Fill the action queue with incoming valid actions"""
for action in actions: for action in actions:
# Only keep the actions that are newer than the latest action if self._validate_action(action):
if action.get_timestep() <= self.latest_action: self.action_queue.put(action)
continue
self.action_queue.put(action)
def _update_action_queue(self, actions: list[TimedAction]): def _update_action_queue(self, actions: list[TimedAction]):
"""Aggregate incoming actions into the action queue. """Aggregate incoming actions into the action queue.
@@ -176,6 +185,7 @@ class RobotClient:
Args: Args:
actions: List of TimedAction instances to queue actions: List of TimedAction instances to queue
""" """
print("*** Current latest action: ", self.latest_action, "***")
print("\t**** Current queue content ****: ") print("\t**** Current queue content ****: ")
self._inspect_action_queue() self._inspect_action_queue()
@@ -188,55 +198,22 @@ class RobotClient:
print("\t*** Queue after clearing and filling ****: ") print("\t*** Queue after clearing and filling ****: ")
self._inspect_action_queue() self._inspect_action_queue()
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
"""Create TimedAction instances from raw action data.
Args:
action_data: Numpy array of shape (chunk_size, 7) where first column
is timestep and remaining columns are action values.
Returns:
List of TimedAction instances
"""
timed_actions = []
for action in action_data:
timestep = int(action[0]) # First element is the timestep
action_values = action[1:] # Remaining elements are the action
timed_actions.append(TimedAction(timestep, action_values))
return timed_actions
def receive_actions(self): def receive_actions(self):
"""Receive actions from the policy server""" """Receive actions from the policy server"""
# Wait at barrier for synchronized start # Wait at barrier for synchronized start
self.start_barrier.wait() self.start_barrier.wait()
print("Action receiving thread starting") print("Action receiving thread starting")
print(self.timesteps())
while self.running: while self.running:
try: try:
# Use StreamActions to get a stream of actions from the server # Use StreamActions to get a stream of actions from the server
action_chunks_counter = 0 for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
for action in self.stub.StreamActions(async_inference_pb2.Empty()): # Deserialize bytes back into list[TimedAction]
# Read the action data which includes timesteps timed_actions = pickle.loads(actions_chunk.data) # nosec
# Shape is (chunk_size, 7) where first column is timestep
action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
self.action_chunk_size, 7
)
print("*** Receiving actions ****: ")
# Convert raw action data to TimedAction instances
timed_actions = self._create_timed_actions(action_data)
# strategy for queue composition is specified in the method # strategy for queue composition is specified in the method
self._clear_and_fill_action_queue(timed_actions) self._clear_and_fill_action_queue(timed_actions)
action_chunks_counter += 1
if action_chunks_counter > 2:
raise ValueError("Too many action chunks received")
except grpc.RpcError as e: except grpc.RpcError as e:
print(f"Error receiving actions: {e}") print(f"Error receiving actions: {e}")
time.sleep(idle_wait) # Avoid tight loop on error time.sleep(idle_wait) # Avoid tight loop on error
@@ -258,16 +235,24 @@ class RobotClient:
while self.running: while self.running:
# Get the next action from the queue # Get the next action from the queue
time.sleep(environment_dt)
timed_action = self._get_next_action() timed_action = self._get_next_action()
if timed_action is not None: if timed_action is not None:
self.latest_action = timed_action.get_timestep() # self.latest_action = timed_action.get_timestep()
self.latest_action = timed_action.get_timestamp()
# Convert action to tensor and send to robot # Convert action to tensor and send to robot
if self.use_robot: if self.use_robot:
self.robot.send_action(torch.tensor(timed_action.get_action())) # Acquire lock before accessing the robot
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
time.sleep(environment_dt) try:
self.robot.send_action(timed_action.get_action())
finally:
# Always release the lock in a finally block to ensure it's released
self.robot_lock.release()
else:
print("Could not acquire robot lock for action execution, retrying next cycle")
else: else:
# No action available, wait and retry fetching from queue # No action available, wait and retry fetching from queue
@@ -282,8 +267,14 @@ class RobotClient:
first_observation = True first_observation = True
while self.running: while self.running:
try: try:
# Get the next observation (not yet serialized) from the function
time.sleep(environment_dt)
observation = get_observation_fn() observation = get_observation_fn()
# Skip if observation is None (couldn't acquire lock)
if observation is None:
continue
# Set appropriate transfer state # Set appropriate transfer state
if first_observation: if first_observation:
state = async_inference_pb2.TRANSFER_BEGIN state = async_inference_pb2.TRANSFER_BEGIN
@@ -291,12 +282,6 @@ class RobotClient:
else: else:
state = async_inference_pb2.TRANSFER_MIDDLE state = async_inference_pb2.TRANSFER_MIDDLE
# Build timestep element in observation
# observation = np.hstack(
# (np.array([self.latest_action]), observation)
# ).astype(np.float32)
time.sleep(environment_dt) # Control the observation sending rate
self.send_observation(observation, state) self.send_observation(observation, state)
except Exception as e: except Exception as e:
@@ -315,8 +300,21 @@ def async_client():
if not hasattr(get_observation, "counter"): if not hasattr(get_observation, "counter"):
get_observation.counter = 0 get_observation.counter = 0
# Create observation with incrementing first element # Acquire lock before accessing the robot
observation = np.array([get_observation.counter, 0, 0], dtype=np.float32) observation_content = None
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
try:
observation_content = client.robot.capture_observation()
finally:
# Always release the lock in a finally block to ensure it's released
client.robot_lock.release()
else:
print("Could not acquire robot lock for observation capture, skipping this cycle")
return None # Return None to indicate no observation was captured
observation = TimedObservation(
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
)
# Increment counter for next call # Increment counter for next call
get_observation.counter += 1 get_observation.counter += 1