forked from tangger/lerobot
fix: separate threads for obs streaming, action receiving & execution + action queue reconciliation
This commit is contained in:
@@ -1,67 +1,115 @@
|
||||
import grpc
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from concurrent import futures
|
||||
from queue import Queue, Empty
|
||||
from typing import Optional, Union
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from typing import Optional
|
||||
|
||||
import async_inference_pb2 # type: ignore
|
||||
import async_inference_pb2_grpc # type: ignore
|
||||
import grpc
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
|
||||
environment_dt = 1 / 30
|
||||
idle_wait = 0.1
|
||||
|
||||
|
||||
class TimedData:
|
||||
def __init__(self, timestep: int, data: np.ndarray):
|
||||
self.timestep = timestep
|
||||
self.data = data
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
def get_timestep(self):
|
||||
return self.timestep
|
||||
|
||||
|
||||
class TimedAction(TimedData):
|
||||
def __init__(self, timestep: int, action: np.ndarray):
|
||||
super().__init__(timestep, action)
|
||||
|
||||
def get_action(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class TimedObservation(TimedData):
|
||||
def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0):
|
||||
super().__init__(timestep, observation)
|
||||
self.transfer_state = transfer_state
|
||||
|
||||
def get_observation(self):
|
||||
return self.get_data()
|
||||
|
||||
|
||||
class RobotClient:
|
||||
def __init__(self, server_address="localhost:50051"):
|
||||
def __init__(
|
||||
self,
|
||||
# cfg: RobotConfig,
|
||||
server_address="localhost:50051",
|
||||
use_robot=True,
|
||||
):
|
||||
self.channel = grpc.insecure_channel(server_address)
|
||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
|
||||
|
||||
self.running = False
|
||||
self.first_observation_sent = False
|
||||
self.action_chunk_size = 10
|
||||
self.latest_action = 0
|
||||
self.action_chunk_size = 20
|
||||
|
||||
self.action_queue = Queue()
|
||||
self.action_queue_lock = threading.Lock()
|
||||
self.start_barrier = threading.Barrier(3) # Barrier for 3 threads
|
||||
|
||||
self.observation_timestep = 0
|
||||
|
||||
self.use_robot = use_robot
|
||||
if self.use_robot:
|
||||
self.robot = make_robot("so100")
|
||||
self.robot.connect()
|
||||
|
||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||
print("Robot connected")
|
||||
|
||||
def timesteps(self):
|
||||
"""Get the timesteps of the actions in the queue"""
|
||||
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||
|
||||
# debugging purposes
|
||||
self.action_buffer = []
|
||||
|
||||
def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
self.stub.Ready(async_inference_pb2.Empty())
|
||||
print("Connected to policy server")
|
||||
|
||||
|
||||
self.running = True
|
||||
return True
|
||||
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self.running = False
|
||||
self.channel.close()
|
||||
|
||||
|
||||
def send_observation(
|
||||
self,
|
||||
observation_data: Union[np.ndarray, bytes],
|
||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
self,
|
||||
observation_data: np.ndarray,
|
||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||
) -> bool:
|
||||
"""Send observation to the policy server.
|
||||
Returns True if the observation was sent successfully, False otherwise."""
|
||||
if not self.running:
|
||||
print("Client not running")
|
||||
return False
|
||||
|
||||
# Convert observation data to bytes
|
||||
if not isinstance(observation_data, bytes):
|
||||
observation_data = np.array(observation_data).tobytes()
|
||||
observation_data = observation_data.tobytes()
|
||||
|
||||
observation = async_inference_pb2.Observation(
|
||||
transfer_state=transfer_state,
|
||||
data=observation_data
|
||||
)
|
||||
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)
|
||||
|
||||
try:
|
||||
_ = self.stub.SendObservations(iter([observation]))
|
||||
@@ -72,75 +120,170 @@ class RobotClient:
|
||||
except grpc.RpcError as e:
|
||||
print(f"Error sending observation: {e}")
|
||||
return False
|
||||
|
||||
def _should_replace_queue(self, percentage_left: float = 0.5) -> bool:
|
||||
"""Check if we should replace the queue based on consumption rate"""
|
||||
with self.action_queue_lock:
|
||||
current_size = self.action_queue.qsize()
|
||||
return current_size/self.action_chunk_size <= percentage_left
|
||||
|
||||
def _clear_and_refill_queue(self, actions: list[np.ndarray]):
|
||||
"""Clear the existing queue and fill it with new actions"""
|
||||
assert len(actions) == self.action_chunk_size, \
|
||||
f"Action batch size must match action chunk!" \
|
||||
f"size: {len(actions)} != {self.action_chunk_size}"
|
||||
|
||||
with self.action_queue_lock:
|
||||
# Clear the queue
|
||||
while not self.action_queue.empty():
|
||||
try:
|
||||
self.action_queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# Fill with new actions
|
||||
for action in actions:
|
||||
self.action_queue.put(action)
|
||||
def _validate_action(self, action: np.ndarray):
|
||||
"""Validate the action"""
|
||||
assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}"
|
||||
|
||||
return True
|
||||
|
||||
def _validate_action_chunk(self, actions: list[np.ndarray]):
|
||||
"""Validate the action chunk"""
|
||||
assert len(actions) == self.action_chunk_size, (
|
||||
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
|
||||
)
|
||||
|
||||
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
|
||||
|
||||
return True
|
||||
|
||||
def _inspect_action_queue(self):
|
||||
"""Inspect the action queue"""
|
||||
print("Queue size: ", self.action_queue.qsize())
|
||||
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
|
||||
|
||||
def _clear_queue(self):
|
||||
"""Clear the existing queue"""
|
||||
while not self.action_queue.empty():
|
||||
try:
|
||||
self.action_queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Fill the action queue with incoming actions"""
|
||||
for action in actions:
|
||||
# Only keep the actions that are newer than the latest action
|
||||
if action.get_timestep() <= self.latest_action:
|
||||
continue
|
||||
|
||||
self.action_queue.put(action)
|
||||
|
||||
def _update_action_queue(self, actions: list[TimedAction]):
|
||||
"""Aggregate incoming actions into the action queue.
|
||||
Raises NotImplementedError as this is not implemented yet.
|
||||
|
||||
Args:
|
||||
actions: List of TimedAction instances to queue
|
||||
"""
|
||||
# TODO: Implement this
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
|
||||
"""Clear the existing queue and fill it with new actions.
|
||||
This is a higher-level function that combines clearing and filling operations.
|
||||
|
||||
Args:
|
||||
actions: List of TimedAction instances to queue
|
||||
"""
|
||||
print("\t**** Current queue content ****: ")
|
||||
self._inspect_action_queue()
|
||||
|
||||
print("\t*** Incoming actions ****: ")
|
||||
print([a.get_timestep() for a in actions])
|
||||
|
||||
self._clear_queue()
|
||||
self._fill_action_queue(actions)
|
||||
|
||||
print("\t*** Queue after clearing and filling ****: ")
|
||||
self._inspect_action_queue()
|
||||
|
||||
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
|
||||
"""Create TimedAction instances from raw action data.
|
||||
|
||||
Args:
|
||||
action_data: Numpy array of shape (chunk_size, 7) where first column
|
||||
is timestep and remaining columns are action values.
|
||||
|
||||
Returns:
|
||||
List of TimedAction instances
|
||||
"""
|
||||
timed_actions = []
|
||||
for action in action_data:
|
||||
timestep = int(action[0]) # First element is the timestep
|
||||
action_values = action[1:] # Remaining elements are the action
|
||||
timed_actions.append(TimedAction(timestep, action_values))
|
||||
|
||||
return timed_actions
|
||||
|
||||
def receive_actions(self):
|
||||
"""Receive actions from the policy server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Action receiving thread starting")
|
||||
|
||||
print(self.timesteps())
|
||||
|
||||
while self.running:
|
||||
# Wait until first observation is sent
|
||||
if not self.first_observation_sent:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Use StreamActions to get a stream of actions from the server
|
||||
action_batch = []
|
||||
action_chunks_counter = 0
|
||||
for action in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||
# NOTE: reading from buffer with numpy requires reshaping
|
||||
action_data = np.frombuffer(
|
||||
action.data, dtype=np.float32
|
||||
).reshape(self.action_chunk_size, -1)
|
||||
# Read the action data which includes timesteps
|
||||
# Shape is (chunk_size, 7) where first column is timestep
|
||||
action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
|
||||
self.action_chunk_size, 7
|
||||
)
|
||||
|
||||
print("*** Receiving actions ****: ")
|
||||
# Convert raw action data to TimedAction instances
|
||||
timed_actions = self._create_timed_actions(action_data)
|
||||
|
||||
# strategy for queue composition is specified in the method
|
||||
self._clear_and_fill_action_queue(timed_actions)
|
||||
|
||||
action_chunks_counter += 1
|
||||
|
||||
if action_chunks_counter > 2:
|
||||
raise ValueError("Too many action chunks received")
|
||||
|
||||
for a in action_data:
|
||||
action_batch.append(a)
|
||||
|
||||
# Replace entire queue with new batch of actions
|
||||
if action_batch and self._should_replace_queue():
|
||||
self._clear_and_refill_queue(action_batch)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
print(f"Error receiving actions: {e}")
|
||||
time.sleep(1) # Avoid tight loop on error
|
||||
|
||||
def get_next_action(self) -> Optional[np.ndarray]:
|
||||
time.sleep(idle_wait) # Avoid tight loop on error
|
||||
|
||||
def _get_next_action(self) -> Optional[TimedAction]:
|
||||
"""Get the next action from the queue"""
|
||||
try:
|
||||
with self.action_queue_lock:
|
||||
return self.action_queue.get_nowait()
|
||||
action = self.action_queue.get_nowait()
|
||||
return action
|
||||
|
||||
except Empty:
|
||||
return None
|
||||
|
||||
|
||||
def execute_actions(self):
|
||||
"""Continuously execute actions from the queue"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Action execution thread starting")
|
||||
|
||||
while self.running:
|
||||
# Get the next action from the queue
|
||||
timed_action = self._get_next_action()
|
||||
|
||||
if timed_action is not None:
|
||||
self.latest_action = timed_action.get_timestep()
|
||||
|
||||
# Convert action to tensor and send to robot
|
||||
if self.use_robot:
|
||||
self.robot.send_action(torch.tensor(timed_action.get_action()))
|
||||
|
||||
time.sleep(environment_dt)
|
||||
|
||||
else:
|
||||
# No action available, wait and retry fetching from queue
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def stream_observations(self, get_observation_fn):
|
||||
"""Continuously stream observations to the server"""
|
||||
# Wait at barrier for synchronized start
|
||||
self.start_barrier.wait()
|
||||
print("Observation streaming thread starting")
|
||||
|
||||
first_observation = True
|
||||
while self.running:
|
||||
try:
|
||||
observation = get_observation_fn()
|
||||
|
||||
|
||||
# Set appropriate transfer state
|
||||
if first_observation:
|
||||
state = async_inference_pb2.TRANSFER_BEGIN
|
||||
@@ -148,53 +291,69 @@ class RobotClient:
|
||||
else:
|
||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||
|
||||
# Build timestep element in observation
|
||||
# observation = np.hstack(
|
||||
# (np.array([self.latest_action]), observation)
|
||||
# ).astype(np.float32)
|
||||
|
||||
time.sleep(environment_dt) # Control the observation sending rate
|
||||
self.send_observation(observation, state)
|
||||
time.sleep(0.1) # Adjust rate as needed
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in observation sender: {e}")
|
||||
time.sleep(1)
|
||||
time.sleep(idle_wait)
|
||||
|
||||
def example_usage():
|
||||
|
||||
def async_client():
|
||||
# Example of how to use the RobotClient
|
||||
client = RobotClient()
|
||||
|
||||
|
||||
if client.start():
|
||||
# Function to generate mock observations
|
||||
def get_mock_observation():
|
||||
return np.random.randint(0, 10, size=10).astype(np.float32)
|
||||
|
||||
def get_observation():
|
||||
# Create a counter attribute if it doesn't exist
|
||||
if not hasattr(get_observation, "counter"):
|
||||
get_observation.counter = 0
|
||||
|
||||
# Create observation with incrementing first element
|
||||
observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
|
||||
|
||||
# Increment counter for next call
|
||||
get_observation.counter += 1
|
||||
|
||||
return observation
|
||||
|
||||
print("Starting all threads...")
|
||||
|
||||
# Create and start observation sender thread
|
||||
obs_thread = threading.Thread(
|
||||
target=client.stream_observations,
|
||||
args=(get_mock_observation,)
|
||||
)
|
||||
obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
|
||||
obs_thread.daemon = True
|
||||
obs_thread.start()
|
||||
|
||||
|
||||
# Create and start action receiver thread
|
||||
action_thread = threading.Thread(target=client.receive_actions)
|
||||
action_thread.daemon = True
|
||||
action_thread.start()
|
||||
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions)
|
||||
action_receiver_thread.daemon = True
|
||||
|
||||
# Create action execution thread
|
||||
action_execution_thread = threading.Thread(target=client.execute_actions)
|
||||
action_execution_thread.daemon = True
|
||||
|
||||
# Start all threads
|
||||
obs_thread.start()
|
||||
action_receiver_thread.start()
|
||||
action_execution_thread.start()
|
||||
|
||||
try:
|
||||
# Main loop - action execution
|
||||
while True:
|
||||
print(client.action_queue.qsize())
|
||||
action = client.get_next_action()
|
||||
if action is not None:
|
||||
print(f"Executing action: {action}")
|
||||
time.sleep(1)
|
||||
else:
|
||||
print("No action available")
|
||||
time.sleep(0.5)
|
||||
|
||||
# Main thread just keeps everything alive
|
||||
while client.running:
|
||||
time.sleep(idle_wait)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
finally:
|
||||
client.stop()
|
||||
print("Client stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
|
||||
async_client()
|
||||
|
||||
Reference in New Issue
Block a user