fix: separate threads for obs streaming, action receiving & execution + action queue reconciliation
This commit is contained in:
@@ -1,67 +1,115 @@
|
|||||||
import grpc
|
|
||||||
import time
|
|
||||||
import threading
|
import threading
|
||||||
import numpy as np
|
import time
|
||||||
from concurrent import futures
|
from queue import Empty, Queue
|
||||||
from queue import Queue, Empty
|
from typing import Optional
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import async_inference_pb2 # type: ignore
|
import async_inference_pb2 # type: ignore
|
||||||
import async_inference_pb2_grpc # type: ignore
|
import async_inference_pb2_grpc # type: ignore
|
||||||
|
import grpc
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
|
|
||||||
|
environment_dt = 1 / 30
|
||||||
|
idle_wait = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class TimedData:
|
||||||
|
def __init__(self, timestep: int, data: np.ndarray):
|
||||||
|
self.timestep = timestep
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
def get_timestep(self):
|
||||||
|
return self.timestep
|
||||||
|
|
||||||
|
|
||||||
|
class TimedAction(TimedData):
|
||||||
|
def __init__(self, timestep: int, action: np.ndarray):
|
||||||
|
super().__init__(timestep, action)
|
||||||
|
|
||||||
|
def get_action(self):
|
||||||
|
return self.get_data()
|
||||||
|
|
||||||
|
|
||||||
|
class TimedObservation(TimedData):
|
||||||
|
def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0):
|
||||||
|
super().__init__(timestep, observation)
|
||||||
|
self.transfer_state = transfer_state
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
return self.get_data()
|
||||||
|
|
||||||
|
|
||||||
class RobotClient:
|
class RobotClient:
|
||||||
def __init__(self, server_address="localhost:50051"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
# cfg: RobotConfig,
|
||||||
|
server_address="localhost:50051",
|
||||||
|
use_robot=True,
|
||||||
|
):
|
||||||
self.channel = grpc.insecure_channel(server_address)
|
self.channel = grpc.insecure_channel(server_address)
|
||||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
self.first_observation_sent = False
|
self.first_observation_sent = False
|
||||||
self.action_chunk_size = 10
|
self.latest_action = 0
|
||||||
|
self.action_chunk_size = 20
|
||||||
|
|
||||||
self.action_queue = Queue()
|
self.action_queue = Queue()
|
||||||
self.action_queue_lock = threading.Lock()
|
self.start_barrier = threading.Barrier(3) # Barrier for 3 threads
|
||||||
|
|
||||||
|
self.observation_timestep = 0
|
||||||
|
|
||||||
|
self.use_robot = use_robot
|
||||||
|
if self.use_robot:
|
||||||
|
self.robot = make_robot("so100")
|
||||||
|
self.robot.connect()
|
||||||
|
|
||||||
|
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||||
|
print("Robot connected")
|
||||||
|
|
||||||
|
def timesteps(self):
|
||||||
|
"""Get the timesteps of the actions in the queue"""
|
||||||
|
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||||
|
|
||||||
# debugging purposes
|
|
||||||
self.action_buffer = []
|
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the robot client and connect to the policy server"""
|
"""Start the robot client and connect to the policy server"""
|
||||||
try:
|
try:
|
||||||
# client-server handshake
|
# client-server handshake
|
||||||
self.stub.Ready(async_inference_pb2.Empty())
|
self.stub.Ready(async_inference_pb2.Empty())
|
||||||
print("Connected to policy server")
|
print("Connected to policy server")
|
||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
print(f"Failed to connect to policy server: {e}")
|
print(f"Failed to connect to policy server: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the robot client"""
|
"""Stop the robot client"""
|
||||||
self.running = False
|
self.running = False
|
||||||
self.channel.close()
|
self.channel.close()
|
||||||
|
|
||||||
def send_observation(
|
def send_observation(
|
||||||
self,
|
self,
|
||||||
observation_data: Union[np.ndarray, bytes],
|
observation_data: np.ndarray,
|
||||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE
|
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Send observation to the policy server.
|
"""Send observation to the policy server.
|
||||||
Returns True if the observation was sent successfully, False otherwise."""
|
Returns True if the observation was sent successfully, False otherwise."""
|
||||||
if not self.running:
|
if not self.running:
|
||||||
print("Client not running")
|
print("Client not running")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Convert observation data to bytes
|
# Convert observation data to bytes
|
||||||
if not isinstance(observation_data, bytes):
|
observation_data = observation_data.tobytes()
|
||||||
observation_data = np.array(observation_data).tobytes()
|
|
||||||
|
|
||||||
observation = async_inference_pb2.Observation(
|
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)
|
||||||
transfer_state=transfer_state,
|
|
||||||
data=observation_data
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_ = self.stub.SendObservations(iter([observation]))
|
_ = self.stub.SendObservations(iter([observation]))
|
||||||
@@ -72,75 +120,170 @@ class RobotClient:
|
|||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
print(f"Error sending observation: {e}")
|
print(f"Error sending observation: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _should_replace_queue(self, percentage_left: float = 0.5) -> bool:
|
|
||||||
"""Check if we should replace the queue based on consumption rate"""
|
|
||||||
with self.action_queue_lock:
|
|
||||||
current_size = self.action_queue.qsize()
|
|
||||||
return current_size/self.action_chunk_size <= percentage_left
|
|
||||||
|
|
||||||
def _clear_and_refill_queue(self, actions: list[np.ndarray]):
|
def _validate_action(self, action: np.ndarray):
|
||||||
"""Clear the existing queue and fill it with new actions"""
|
"""Validate the action"""
|
||||||
assert len(actions) == self.action_chunk_size, \
|
assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}"
|
||||||
f"Action batch size must match action chunk!" \
|
|
||||||
f"size: {len(actions)} != {self.action_chunk_size}"
|
|
||||||
|
|
||||||
with self.action_queue_lock:
|
|
||||||
# Clear the queue
|
|
||||||
while not self.action_queue.empty():
|
|
||||||
try:
|
|
||||||
self.action_queue.get_nowait()
|
|
||||||
except Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Fill with new actions
|
|
||||||
for action in actions:
|
|
||||||
self.action_queue.put(action)
|
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _validate_action_chunk(self, actions: list[np.ndarray]):
|
||||||
|
"""Validate the action chunk"""
|
||||||
|
assert len(actions) == self.action_chunk_size, (
|
||||||
|
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _inspect_action_queue(self):
|
||||||
|
"""Inspect the action queue"""
|
||||||
|
print("Queue size: ", self.action_queue.qsize())
|
||||||
|
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
|
||||||
|
|
||||||
|
def _clear_queue(self):
|
||||||
|
"""Clear the existing queue"""
|
||||||
|
while not self.action_queue.empty():
|
||||||
|
try:
|
||||||
|
self.action_queue.get_nowait()
|
||||||
|
except Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||||
|
"""Fill the action queue with incoming actions"""
|
||||||
|
for action in actions:
|
||||||
|
# Only keep the actions that are newer than the latest action
|
||||||
|
if action.get_timestep() <= self.latest_action:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.action_queue.put(action)
|
||||||
|
|
||||||
|
def _update_action_queue(self, actions: list[TimedAction]):
|
||||||
|
"""Aggregate incoming actions into the action queue.
|
||||||
|
Raises NotImplementedError as this is not implemented yet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: List of TimedAction instances to queue
|
||||||
|
"""
|
||||||
|
# TODO: Implement this
|
||||||
|
raise NotImplementedError("Not implemented")
|
||||||
|
|
||||||
|
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
|
||||||
|
"""Clear the existing queue and fill it with new actions.
|
||||||
|
This is a higher-level function that combines clearing and filling operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: List of TimedAction instances to queue
|
||||||
|
"""
|
||||||
|
print("\t**** Current queue content ****: ")
|
||||||
|
self._inspect_action_queue()
|
||||||
|
|
||||||
|
print("\t*** Incoming actions ****: ")
|
||||||
|
print([a.get_timestep() for a in actions])
|
||||||
|
|
||||||
|
self._clear_queue()
|
||||||
|
self._fill_action_queue(actions)
|
||||||
|
|
||||||
|
print("\t*** Queue after clearing and filling ****: ")
|
||||||
|
self._inspect_action_queue()
|
||||||
|
|
||||||
|
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
|
||||||
|
"""Create TimedAction instances from raw action data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_data: Numpy array of shape (chunk_size, 7) where first column
|
||||||
|
is timestep and remaining columns are action values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TimedAction instances
|
||||||
|
"""
|
||||||
|
timed_actions = []
|
||||||
|
for action in action_data:
|
||||||
|
timestep = int(action[0]) # First element is the timestep
|
||||||
|
action_values = action[1:] # Remaining elements are the action
|
||||||
|
timed_actions.append(TimedAction(timestep, action_values))
|
||||||
|
|
||||||
|
return timed_actions
|
||||||
|
|
||||||
def receive_actions(self):
|
def receive_actions(self):
|
||||||
"""Receive actions from the policy server"""
|
"""Receive actions from the policy server"""
|
||||||
|
# Wait at barrier for synchronized start
|
||||||
|
self.start_barrier.wait()
|
||||||
|
print("Action receiving thread starting")
|
||||||
|
|
||||||
|
print(self.timesteps())
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
# Wait until first observation is sent
|
|
||||||
if not self.first_observation_sent:
|
|
||||||
time.sleep(0.1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use StreamActions to get a stream of actions from the server
|
# Use StreamActions to get a stream of actions from the server
|
||||||
action_batch = []
|
action_chunks_counter = 0
|
||||||
for action in self.stub.StreamActions(async_inference_pb2.Empty()):
|
for action in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||||
# NOTE: reading from buffer with numpy requires reshaping
|
# Read the action data which includes timesteps
|
||||||
action_data = np.frombuffer(
|
# Shape is (chunk_size, 7) where first column is timestep
|
||||||
action.data, dtype=np.float32
|
action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
|
||||||
).reshape(self.action_chunk_size, -1)
|
self.action_chunk_size, 7
|
||||||
|
)
|
||||||
|
|
||||||
|
print("*** Receiving actions ****: ")
|
||||||
|
# Convert raw action data to TimedAction instances
|
||||||
|
timed_actions = self._create_timed_actions(action_data)
|
||||||
|
|
||||||
|
# strategy for queue composition is specified in the method
|
||||||
|
self._clear_and_fill_action_queue(timed_actions)
|
||||||
|
|
||||||
|
action_chunks_counter += 1
|
||||||
|
|
||||||
|
if action_chunks_counter > 2:
|
||||||
|
raise ValueError("Too many action chunks received")
|
||||||
|
|
||||||
for a in action_data:
|
|
||||||
action_batch.append(a)
|
|
||||||
|
|
||||||
# Replace entire queue with new batch of actions
|
|
||||||
if action_batch and self._should_replace_queue():
|
|
||||||
self._clear_and_refill_queue(action_batch)
|
|
||||||
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
print(f"Error receiving actions: {e}")
|
print(f"Error receiving actions: {e}")
|
||||||
time.sleep(1) # Avoid tight loop on error
|
time.sleep(idle_wait) # Avoid tight loop on error
|
||||||
|
|
||||||
def get_next_action(self) -> Optional[np.ndarray]:
|
def _get_next_action(self) -> Optional[TimedAction]:
|
||||||
"""Get the next action from the queue"""
|
"""Get the next action from the queue"""
|
||||||
try:
|
try:
|
||||||
with self.action_queue_lock:
|
action = self.action_queue.get_nowait()
|
||||||
return self.action_queue.get_nowait()
|
return action
|
||||||
|
|
||||||
except Empty:
|
except Empty:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def execute_actions(self):
|
||||||
|
"""Continuously execute actions from the queue"""
|
||||||
|
# Wait at barrier for synchronized start
|
||||||
|
self.start_barrier.wait()
|
||||||
|
print("Action execution thread starting")
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
# Get the next action from the queue
|
||||||
|
timed_action = self._get_next_action()
|
||||||
|
|
||||||
|
if timed_action is not None:
|
||||||
|
self.latest_action = timed_action.get_timestep()
|
||||||
|
|
||||||
|
# Convert action to tensor and send to robot
|
||||||
|
if self.use_robot:
|
||||||
|
self.robot.send_action(torch.tensor(timed_action.get_action()))
|
||||||
|
|
||||||
|
time.sleep(environment_dt)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No action available, wait and retry fetching from queue
|
||||||
|
time.sleep(idle_wait)
|
||||||
|
|
||||||
def stream_observations(self, get_observation_fn):
|
def stream_observations(self, get_observation_fn):
|
||||||
"""Continuously stream observations to the server"""
|
"""Continuously stream observations to the server"""
|
||||||
|
# Wait at barrier for synchronized start
|
||||||
|
self.start_barrier.wait()
|
||||||
|
print("Observation streaming thread starting")
|
||||||
|
|
||||||
first_observation = True
|
first_observation = True
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
observation = get_observation_fn()
|
observation = get_observation_fn()
|
||||||
|
|
||||||
# Set appropriate transfer state
|
# Set appropriate transfer state
|
||||||
if first_observation:
|
if first_observation:
|
||||||
state = async_inference_pb2.TRANSFER_BEGIN
|
state = async_inference_pb2.TRANSFER_BEGIN
|
||||||
@@ -148,53 +291,69 @@ class RobotClient:
|
|||||||
else:
|
else:
|
||||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||||
|
|
||||||
|
# Build timestep element in observation
|
||||||
|
# observation = np.hstack(
|
||||||
|
# (np.array([self.latest_action]), observation)
|
||||||
|
# ).astype(np.float32)
|
||||||
|
|
||||||
|
time.sleep(environment_dt) # Control the observation sending rate
|
||||||
self.send_observation(observation, state)
|
self.send_observation(observation, state)
|
||||||
time.sleep(0.1) # Adjust rate as needed
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in observation sender: {e}")
|
print(f"Error in observation sender: {e}")
|
||||||
time.sleep(1)
|
time.sleep(idle_wait)
|
||||||
|
|
||||||
def example_usage():
|
|
||||||
|
def async_client():
|
||||||
# Example of how to use the RobotClient
|
# Example of how to use the RobotClient
|
||||||
client = RobotClient()
|
client = RobotClient()
|
||||||
|
|
||||||
if client.start():
|
if client.start():
|
||||||
# Function to generate mock observations
|
# Function to generate mock observations
|
||||||
def get_mock_observation():
|
def get_observation():
|
||||||
return np.random.randint(0, 10, size=10).astype(np.float32)
|
# Create a counter attribute if it doesn't exist
|
||||||
|
if not hasattr(get_observation, "counter"):
|
||||||
|
get_observation.counter = 0
|
||||||
|
|
||||||
|
# Create observation with incrementing first element
|
||||||
|
observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
|
||||||
|
|
||||||
|
# Increment counter for next call
|
||||||
|
get_observation.counter += 1
|
||||||
|
|
||||||
|
return observation
|
||||||
|
|
||||||
|
print("Starting all threads...")
|
||||||
|
|
||||||
# Create and start observation sender thread
|
# Create and start observation sender thread
|
||||||
obs_thread = threading.Thread(
|
obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
|
||||||
target=client.stream_observations,
|
|
||||||
args=(get_mock_observation,)
|
|
||||||
)
|
|
||||||
obs_thread.daemon = True
|
obs_thread.daemon = True
|
||||||
obs_thread.start()
|
|
||||||
|
|
||||||
# Create and start action receiver thread
|
# Create and start action receiver thread
|
||||||
action_thread = threading.Thread(target=client.receive_actions)
|
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||||
action_thread.daemon = True
|
action_receiver_thread.daemon = True
|
||||||
action_thread.start()
|
|
||||||
|
# Create action execution thread
|
||||||
|
action_execution_thread = threading.Thread(target=client.execute_actions)
|
||||||
|
action_execution_thread.daemon = True
|
||||||
|
|
||||||
|
# Start all threads
|
||||||
|
obs_thread.start()
|
||||||
|
action_receiver_thread.start()
|
||||||
|
action_execution_thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Main loop - action execution
|
# Main thread just keeps everything alive
|
||||||
while True:
|
while client.running:
|
||||||
print(client.action_queue.qsize())
|
time.sleep(idle_wait)
|
||||||
action = client.get_next_action()
|
|
||||||
if action is not None:
|
|
||||||
print(f"Executing action: {action}")
|
|
||||||
time.sleep(1)
|
|
||||||
else:
|
|
||||||
print("No action available")
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
client.stop()
|
client.stop()
|
||||||
|
print("Client stopped")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
example_usage()
|
async_client()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user