fix: separate threads for obs streaming, action receiving & execution + action queue reconciliation

This commit is contained in:
Francesco Capuano
2025-04-17 21:09:58 +02:00
parent d40e74f371
commit 0fc9a4341f

View File

@@ -1,67 +1,115 @@
import threading
import time
from concurrent import futures
from queue import Empty, Queue
from typing import Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore
import grpc
import numpy as np
import torch

from lerobot.common.robot_devices.robots.utils import make_robot

# Target period (s) of the robot control loop (~30 Hz)
environment_dt = 1 / 30
# Generic back-off sleep (s) used whenever a thread has nothing to do
idle_wait = 0.1
class TimedData:
    """Base wrapper pairing an arbitrary payload with the discrete timestep it refers to."""

    def __init__(self, timestep: int, data: np.ndarray):
        self.timestep = timestep
        self.data = data

    def get_data(self):
        """Return the wrapped payload."""
        return self.data

    def get_timestep(self):
        """Return the integer timestep this payload belongs to."""
        return self.timestep


class TimedAction(TimedData):
    """An action array tagged with the timestep it should be executed at."""

    def __init__(self, timestep: int, action: np.ndarray):
        super().__init__(timestep, action)

    def get_action(self):
        """Alias for ``get_data`` — reads better at action-handling call sites."""
        return self.get_data()


class TimedObservation(TimedData):
    """An observation array tagged with its timestep plus a transfer-state flag."""

    def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0):
        super().__init__(timestep, observation)
        self.transfer_state = transfer_state

    def get_observation(self):
        """Alias for ``get_data`` — reads better at observation-handling call sites."""
        return self.get_data()
class RobotClient:
    """Client that talks to a remote policy server over gRPC.

    Three threads share one instance: one streams observations to the server,
    one receives action chunks back, and one executes actions on the robot.
    """

    def __init__(
        self,
        # cfg: RobotConfig,
        server_address="localhost:50051",
        use_robot=True,
    ):
        # gRPC plumbing to the policy server
        self.channel = grpc.insecure_channel(server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)

        self.running = False
        self.first_observation_sent = False

        # Timestep of the most recently executed action; incoming actions that
        # are not strictly newer than this are dropped during queue refills.
        self.latest_action = 0
        self.action_chunk_size = 20

        self.action_queue = Queue()
        self.action_queue_lock = threading.Lock()
        # Barrier for 3 threads: obs streaming, action receiving, action execution
        self.start_barrier = threading.Barrier(3)

        self.observation_timestep = 0

        self.use_robot = use_robot
        if self.use_robot:
            self.robot = make_robot("so100")
            self.robot.connect()
            time.sleep(idle_wait)  # sleep waiting for cameras to activate
            print("Robot connected")

        # debugging purposes
        self.action_buffer = []
def start(self):
    """Perform the client-server handshake and mark the client as running.

    Returns:
        True when the policy server answered the Ready RPC, False otherwise.
    """
    try:
        # client-server handshake
        self.stub.Ready(async_inference_pb2.Empty())
        print("Connected to policy server")
    except grpc.RpcError as e:
        print(f"Failed to connect to policy server: {e}")
        return False
    else:
        self.running = True
        return True
def stop(self):
    """Stop the robot client and release the gRPC channel."""
    # Signal all worker threads to exit their loops, then drop the connection.
    self.running = False
    self.channel.close()
def send_observation(
    self,
    observation_data: np.ndarray,
    transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
) -> bool:
    """Send observation to the policy server.
    Returns True if the observation was sent successfully, False otherwise."""
    if not self.running:
        print("Client not running")
        return False

    # Serialize the observation payload for transport
    if not isinstance(observation_data, bytes):
        observation_data = observation_data.tobytes()

    observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)

    try:
        _ = self.stub.SendObservations(iter([observation]))
        # NOTE(review): the diff hunk elided the success path here; presumably it
        # marks the first observation as sent and returns True — confirm against
        # the full file.
        self.first_observation_sent = True
        return True
    except grpc.RpcError as e:
        print(f"Error sending observation: {e}")
        return False
def _should_replace_queue(self, percentage_left: float = 0.5) -> bool:
    """Return True when the queue has drained to ``percentage_left`` of a chunk or less.

    NOTE(review): this looks like the pre-commit consumption-rate heuristic; the
    newer reconciliation path replaces the queue on every received chunk instead.
    """
    with self.action_queue_lock:
        remaining = self.action_queue.qsize()
        return remaining / self.action_chunk_size <= percentage_left
def _clear_and_refill_queue(self, actions: list[np.ndarray]):
    """Swap the queue contents for a fresh batch, under the queue lock.

    Args:
        actions: exactly ``action_chunk_size`` raw action arrays.
    """
    assert len(actions) == self.action_chunk_size, (
        f"Action batch size must match action chunk!"
        f"size: {len(actions)} != {self.action_chunk_size}"
    )

    with self.action_queue_lock:
        # Drain whatever is still pending
        while not self.action_queue.empty():
            try:
                self.action_queue.get_nowait()
            except Empty:
                break
        # Enqueue the replacement batch
        for item in actions:
            self.action_queue.put(item)
def _validate_action(self, action: np.ndarray):
    """Check a single action row has shape (7,): timestep + 6 action values."""
    expected_shape = (7,)
    assert action.shape == expected_shape, (
        f"Action shape must be (7,) (including timestep), got {action.shape}"
    )
    return True
def _validate_action_chunk(self, actions: list[np.ndarray]):
    """Check a chunk has the configured length and every row is well-formed."""
    assert len(actions) == self.action_chunk_size, (
        f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
    )
    # Each row is validated individually; _validate_action asserts on failure.
    assert all(self._validate_action(item) for item in actions), "Invalid action in chunk"
    return True
def _inspect_action_queue(self):
    """Debug aid: print the queue size and the sorted queued timesteps."""
    print("Queue size: ", self.action_queue.qsize())
    # Peeking at Queue.queue is unlocked — fine for debugging output only.
    print("Queue contents: ", sorted(item.get_timestep() for item in self.action_queue.queue))
def _clear_queue(self):
    """Drain every pending item from the action queue."""
    while True:
        if self.action_queue.empty():
            break
        try:
            self.action_queue.get_nowait()
        except Empty:
            # Another consumer beat us to the last item — done.
            break
def _fill_action_queue(self, actions: list[TimedAction]):
    """Enqueue incoming actions, discarding any not newer than the last executed one.

    Args:
        actions: TimedAction instances to enqueue.
    """
    for timed_action in actions:
        # Only keep actions strictly newer than the latest executed timestep.
        if timed_action.get_timestep() > self.latest_action:
            self.action_queue.put(timed_action)
def _update_action_queue(self, actions: list[TimedAction]):
    """Merge incoming actions into the existing queue (reconciliation strategy).

    Args:
        actions: List of TimedAction instances to queue

    Raises:
        NotImplementedError: always — the merge strategy is not written yet.
    """
    # TODO: Implement this
    raise NotImplementedError("Not implemented")
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
    """Replace the queue contents wholesale with a fresh batch of actions.

    Combines the clear and fill primitives, logging the queue state before and
    after so the reconciliation behaviour can be inspected at runtime.

    Args:
        actions: List of TimedAction instances to queue
    """
    print("\t**** Current queue content ****: ")
    self._inspect_action_queue()

    print("\t*** Incoming actions ****: ")
    print([a.get_timestep() for a in actions])

    self._clear_queue()
    self._fill_action_queue(actions)

    print("\t*** Queue after clearing and filling ****: ")
    self._inspect_action_queue()
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
    """Build TimedAction objects out of raw action rows.

    Args:
        action_data: array of shape (chunk_size, 7); column 0 is the timestep,
            the remaining columns are the action values.

    Returns:
        One TimedAction per row, in row order.
    """
    # Row layout: [timestep, action...] — split each row accordingly.
    return [TimedAction(int(row[0]), row[1:]) for row in action_data]
def receive_actions(self):
    """Receive action chunks from the policy server and reconcile the local queue.

    Runs on its own thread: waits at the start barrier, holds off until the
    first observation has been sent, then consumes the StreamActions RPC,
    replacing the action queue with each incoming chunk.
    """
    # Wait at barrier for synchronized start with the other two threads
    self.start_barrier.wait()
    print("Action receiving thread starting")

    while self.running:
        # Wait until first observation is sent — the server has nothing to
        # act on before that.
        if not self.first_observation_sent:
            time.sleep(idle_wait)
            continue

        try:
            # Use StreamActions to get a stream of actions from the server
            for action in self.stub.StreamActions(async_inference_pb2.Empty()):
                # NOTE: reading from buffer with numpy requires reshaping.
                # Shape is (chunk_size, 7) where first column is timestep.
                action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
                    self.action_chunk_size, 7
                )

                print("*** Receiving actions ****: ")

                # Convert raw action data to TimedAction instances
                timed_actions = self._create_timed_actions(action_data)

                # strategy for queue composition is specified in the method
                self._clear_and_fill_action_queue(timed_actions)
        except grpc.RpcError as e:
            print(f"Error receiving actions: {e}")
            time.sleep(idle_wait)  # Avoid tight loop on error
def _get_next_action(self) -> Optional[TimedAction]:
    """Pop the next action from the queue, or return None when it is empty.

    Queue.get_nowait is internally synchronized, so no extra lock is needed
    for a single pop.
    """
    try:
        return self.action_queue.get_nowait()
    except Empty:
        return None
def execute_actions(self):
    """Continuously pop actions off the queue and send them to the robot.

    Runs on its own thread after the shared start barrier is released.
    """
    # Wait at barrier for synchronized start
    self.start_barrier.wait()
    print("Action execution thread starting")

    while self.running:
        timed_action = self._get_next_action()

        if timed_action is None:
            # Nothing queued — back off briefly before polling again.
            time.sleep(idle_wait)
            continue

        # Record the newest executed timestep so stale incoming actions
        # can be filtered out during queue refills.
        self.latest_action = timed_action.get_timestep()

        if self.use_robot:
            # Convert action to tensor and send to robot, pacing at the
            # environment control rate.
            self.robot.send_action(torch.tensor(timed_action.get_action()))
            time.sleep(environment_dt)
def stream_observations(self, get_observation_fn):
    """Continuously stream observations to the server.

    Args:
        get_observation_fn: zero-argument callable returning the next
            observation (an np.ndarray per send_observation's contract).
    """
    # Wait at barrier for synchronized start
    self.start_barrier.wait()
    print("Observation streaming thread starting")

    first_observation = True
    while self.running:
        try:
            observation = get_observation_fn()

            # Set appropriate transfer state: the very first observation opens
            # the transfer, every later one is a middle chunk.
            if first_observation:
                state = async_inference_pb2.TRANSFER_BEGIN
                # NOTE(review): the diff hunk elided the lines here; presumably
                # first_observation is reset to False — confirm against the full file.
                first_observation = False
            else:
                state = async_inference_pb2.TRANSFER_MIDDLE

            time.sleep(environment_dt)  # Control the observation sending rate
            self.send_observation(observation, state)
        except Exception as e:
            print(f"Error in observation sender: {e}")
            time.sleep(idle_wait)
def async_client():
    """Example of how to use the RobotClient: spin up the three worker threads
    (observation streaming, action receiving, action execution) and keep the
    main thread alive until interrupted."""
    client = RobotClient()

    if client.start():
        def get_observation():
            """Mock observation source: first element is a monotonically increasing counter."""
            # Create a counter attribute on first call
            if not hasattr(get_observation, "counter"):
                get_observation.counter = 0

            observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
            get_observation.counter += 1
            return observation

        print("Starting all threads...")

        # Create observation sender thread
        obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
        obs_thread.daemon = True

        # Create action receiver thread
        action_receiver_thread = threading.Thread(target=client.receive_actions)
        action_receiver_thread.daemon = True

        # Create action execution thread
        action_execution_thread = threading.Thread(target=client.execute_actions)
        action_execution_thread.daemon = True

        # Start all threads together; they synchronize on the client's barrier.
        obs_thread.start()
        action_receiver_thread.start()
        action_execution_thread.start()

        try:
            # Main thread just keeps everything alive
            while client.running:
                time.sleep(idle_wait)
        except KeyboardInterrupt:
            pass
        finally:
            client.stop()
            print("Client stopped")


if __name__ == "__main__":
    async_client()