fix: client sends timed objects only, and uses lock to read & write robot status
This commit is contained in:
@@ -1,12 +1,12 @@
|
|||||||
|
import pickle # nosec
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import async_inference_pb2 # type: ignore
|
import async_inference_pb2 # type: ignore
|
||||||
import async_inference_pb2_grpc # type: ignore
|
import async_inference_pb2_grpc # type: ignore
|
||||||
import grpc
|
import grpc
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||||
@@ -16,28 +16,40 @@ idle_wait = 0.1
|
|||||||
|
|
||||||
|
|
||||||
class TimedData:
|
class TimedData:
|
||||||
def __init__(self, timestep: int, data: np.ndarray):
|
def __init__(self, timestamp: float, data: Any, timestep: int):
|
||||||
self.timestep = timestep
|
"""Initialize a TimedData object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: Unix timestamp of the data's creation.
|
||||||
|
data: The actual data to wrap a timestamp around.
|
||||||
|
"""
|
||||||
|
self.timestamp = timestamp
|
||||||
self.data = data
|
self.data = data
|
||||||
|
self.timestep = timestep
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return self.data
|
return self.data
|
||||||
|
|
||||||
|
def get_timestamp(self):
|
||||||
|
return self.timestamp
|
||||||
|
|
||||||
def get_timestep(self):
|
def get_timestep(self):
|
||||||
return self.timestep
|
return self.timestep
|
||||||
|
|
||||||
|
|
||||||
class TimedAction(TimedData):
|
class TimedAction(TimedData):
|
||||||
def __init__(self, timestep: int, action: np.ndarray):
|
def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
|
||||||
super().__init__(timestep, action)
|
super().__init__(timestamp=timestamp, data=action, timestep=timestep)
|
||||||
|
|
||||||
def get_action(self):
|
def get_action(self):
|
||||||
return self.get_data()
|
return self.get_data()
|
||||||
|
|
||||||
|
|
||||||
class TimedObservation(TimedData):
|
class TimedObservation(TimedData):
|
||||||
def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0):
|
def __init__(
|
||||||
super().__init__(timestep, observation)
|
self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
|
||||||
|
):
|
||||||
|
super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
|
||||||
self.transfer_state = transfer_state
|
self.transfer_state = transfer_state
|
||||||
|
|
||||||
def get_observation(self):
|
def get_observation(self):
|
||||||
@@ -60,9 +72,10 @@ class RobotClient:
|
|||||||
self.action_chunk_size = 20
|
self.action_chunk_size = 20
|
||||||
|
|
||||||
self.action_queue = Queue()
|
self.action_queue = Queue()
|
||||||
self.start_barrier = threading.Barrier(3) # Barrier for 3 threads
|
self.start_barrier = threading.Barrier(3)
|
||||||
|
|
||||||
self.observation_timestep = 0
|
# Create a lock for robot access
|
||||||
|
self.robot_lock = threading.Lock()
|
||||||
|
|
||||||
self.use_robot = use_robot
|
self.use_robot = use_robot
|
||||||
if self.use_robot:
|
if self.use_robot:
|
||||||
@@ -72,8 +85,10 @@ class RobotClient:
|
|||||||
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
time.sleep(idle_wait) # sleep waiting for cameras to activate
|
||||||
print("Robot connected")
|
print("Robot connected")
|
||||||
|
|
||||||
def timesteps(self):
|
self.robot_reading = True
|
||||||
"""Get the timesteps of the actions in the queue"""
|
|
||||||
|
def timestamps(self):
|
||||||
|
"""Get the sorted timesteps of the actions in the queue (despite the method's name, these are timesteps, not timestamps)"""
|
||||||
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
return sorted([action.get_timestep() for action in self.action_queue.queue])
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
@@ -93,11 +108,13 @@ class RobotClient:
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the robot client"""
|
"""Stop the robot client"""
|
||||||
self.running = False
|
self.running = False
|
||||||
|
if self.use_robot and hasattr(self, "robot"):
|
||||||
|
self.robot.disconnect()
|
||||||
self.channel.close()
|
self.channel.close()
|
||||||
|
|
||||||
def send_observation(
|
def send_observation(
|
||||||
self,
|
self,
|
||||||
observation_data: np.ndarray,
|
obs: TimedObservation,
|
||||||
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Send observation to the policy server.
|
"""Send observation to the policy server.
|
||||||
@@ -106,10 +123,10 @@ class RobotClient:
|
|||||||
print("Client not running")
|
print("Client not running")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Convert observation data to bytes
|
assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"
|
||||||
observation_data = observation_data.tobytes()
|
|
||||||
|
|
||||||
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)
|
observation_bytes = pickle.dumps(obs)
|
||||||
|
observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_ = self.stub.SendObservations(iter([observation]))
|
_ = self.stub.SendObservations(iter([observation]))
|
||||||
@@ -121,24 +138,19 @@ class RobotClient:
|
|||||||
print(f"Error sending observation: {e}")
|
print(f"Error sending observation: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _validate_action(self, action: np.ndarray):
|
def _validate_action(self, action: TimedAction):
|
||||||
"""Validate the action"""
|
"""Received actions are kept only when their timestamp is not earlier than that of the latest action taken from the queue"""
|
||||||
assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}"
|
return not action.get_timestamp() < self.latest_action
|
||||||
|
|
||||||
return True
|
def _validate_action_chunk(self, actions: list[TimedAction]):
|
||||||
|
|
||||||
def _validate_action_chunk(self, actions: list[np.ndarray]):
|
|
||||||
"""Validate the action chunk"""
|
|
||||||
assert len(actions) == self.action_chunk_size, (
|
assert len(actions) == self.action_chunk_size, (
|
||||||
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
|
f"Action batch size must match action chunk!size: {len(actions)} != {self.action_chunk_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
|
assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _inspect_action_queue(self):
|
def _inspect_action_queue(self):
|
||||||
"""Inspect the action queue"""
|
|
||||||
print("Queue size: ", self.action_queue.qsize())
|
print("Queue size: ", self.action_queue.qsize())
|
||||||
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
|
print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
|
||||||
|
|
||||||
@@ -151,13 +163,10 @@ class RobotClient:
|
|||||||
break
|
break
|
||||||
|
|
||||||
def _fill_action_queue(self, actions: list[TimedAction]):
|
def _fill_action_queue(self, actions: list[TimedAction]):
|
||||||
"""Fill the action queue with incoming actions"""
|
"""Fill the action queue with incoming valid actions"""
|
||||||
for action in actions:
|
for action in actions:
|
||||||
# Only keep the actions that are newer than the latest action
|
if self._validate_action(action):
|
||||||
if action.get_timestep() <= self.latest_action:
|
self.action_queue.put(action)
|
||||||
continue
|
|
||||||
|
|
||||||
self.action_queue.put(action)
|
|
||||||
|
|
||||||
def _update_action_queue(self, actions: list[TimedAction]):
|
def _update_action_queue(self, actions: list[TimedAction]):
|
||||||
"""Aggregate incoming actions into the action queue.
|
"""Aggregate incoming actions into the action queue.
|
||||||
@@ -176,6 +185,7 @@ class RobotClient:
|
|||||||
Args:
|
Args:
|
||||||
actions: List of TimedAction instances to queue
|
actions: List of TimedAction instances to queue
|
||||||
"""
|
"""
|
||||||
|
print("*** Current latest action: ", self.latest_action, "***")
|
||||||
print("\t**** Current queue content ****: ")
|
print("\t**** Current queue content ****: ")
|
||||||
self._inspect_action_queue()
|
self._inspect_action_queue()
|
||||||
|
|
||||||
@@ -188,55 +198,22 @@ class RobotClient:
|
|||||||
print("\t*** Queue after clearing and filling ****: ")
|
print("\t*** Queue after clearing and filling ****: ")
|
||||||
self._inspect_action_queue()
|
self._inspect_action_queue()
|
||||||
|
|
||||||
def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
|
|
||||||
"""Create TimedAction instances from raw action data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_data: Numpy array of shape (chunk_size, 7) where first column
|
|
||||||
is timestep and remaining columns are action values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of TimedAction instances
|
|
||||||
"""
|
|
||||||
timed_actions = []
|
|
||||||
for action in action_data:
|
|
||||||
timestep = int(action[0]) # First element is the timestep
|
|
||||||
action_values = action[1:] # Remaining elements are the action
|
|
||||||
timed_actions.append(TimedAction(timestep, action_values))
|
|
||||||
|
|
||||||
return timed_actions
|
|
||||||
|
|
||||||
def receive_actions(self):
|
def receive_actions(self):
|
||||||
"""Receive actions from the policy server"""
|
"""Receive actions from the policy server"""
|
||||||
# Wait at barrier for synchronized start
|
# Wait at barrier for synchronized start
|
||||||
self.start_barrier.wait()
|
self.start_barrier.wait()
|
||||||
print("Action receiving thread starting")
|
print("Action receiving thread starting")
|
||||||
|
|
||||||
print(self.timesteps())
|
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
# Use StreamActions to get a stream of actions from the server
|
# Use StreamActions to get a stream of actions from the server
|
||||||
action_chunks_counter = 0
|
for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
|
||||||
for action in self.stub.StreamActions(async_inference_pb2.Empty()):
|
# Deserialize bytes back into list[TimedAction]
|
||||||
# Read the action data which includes timesteps
|
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||||
# Shape is (chunk_size, 7) where first column is timestep
|
|
||||||
action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
|
|
||||||
self.action_chunk_size, 7
|
|
||||||
)
|
|
||||||
|
|
||||||
print("*** Receiving actions ****: ")
|
|
||||||
# Convert raw action data to TimedAction instances
|
|
||||||
timed_actions = self._create_timed_actions(action_data)
|
|
||||||
|
|
||||||
# strategy for queue composition is specified in the method
|
# strategy for queue composition is specified in the method
|
||||||
self._clear_and_fill_action_queue(timed_actions)
|
self._clear_and_fill_action_queue(timed_actions)
|
||||||
|
|
||||||
action_chunks_counter += 1
|
|
||||||
|
|
||||||
if action_chunks_counter > 2:
|
|
||||||
raise ValueError("Too many action chunks received")
|
|
||||||
|
|
||||||
except grpc.RpcError as e:
|
except grpc.RpcError as e:
|
||||||
print(f"Error receiving actions: {e}")
|
print(f"Error receiving actions: {e}")
|
||||||
time.sleep(idle_wait) # Avoid tight loop on error
|
time.sleep(idle_wait) # Avoid tight loop on error
|
||||||
@@ -258,16 +235,24 @@ class RobotClient:
|
|||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
# Get the next action from the queue
|
# Get the next action from the queue
|
||||||
|
time.sleep(environment_dt)
|
||||||
timed_action = self._get_next_action()
|
timed_action = self._get_next_action()
|
||||||
|
|
||||||
if timed_action is not None:
|
if timed_action is not None:
|
||||||
self.latest_action = timed_action.get_timestep()
|
# self.latest_action = timed_action.get_timestep()
|
||||||
|
self.latest_action = timed_action.get_timestamp()
|
||||||
|
|
||||||
# Convert action to tensor and send to robot
|
# Convert action to tensor and send to robot
|
||||||
if self.use_robot:
|
if self.use_robot:
|
||||||
self.robot.send_action(torch.tensor(timed_action.get_action()))
|
# Acquire lock before accessing the robot
|
||||||
|
if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
|
||||||
time.sleep(environment_dt)
|
try:
|
||||||
|
self.robot.send_action(timed_action.get_action())
|
||||||
|
finally:
|
||||||
|
# Always release the lock in a finally block to ensure it's released
|
||||||
|
self.robot_lock.release()
|
||||||
|
else:
|
||||||
|
print("Could not acquire robot lock for action execution, retrying next cycle")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# No action available, wait and retry fetching from queue
|
# No action available, wait and retry fetching from queue
|
||||||
@@ -282,8 +267,14 @@ class RobotClient:
|
|||||||
first_observation = True
|
first_observation = True
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
|
# Get the next observation from the function (None when the robot lock could not be acquired)
|
||||||
|
time.sleep(environment_dt)
|
||||||
observation = get_observation_fn()
|
observation = get_observation_fn()
|
||||||
|
|
||||||
|
# Skip if observation is None (couldn't acquire lock)
|
||||||
|
if observation is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# Set appropriate transfer state
|
# Set appropriate transfer state
|
||||||
if first_observation:
|
if first_observation:
|
||||||
state = async_inference_pb2.TRANSFER_BEGIN
|
state = async_inference_pb2.TRANSFER_BEGIN
|
||||||
@@ -291,12 +282,6 @@ class RobotClient:
|
|||||||
else:
|
else:
|
||||||
state = async_inference_pb2.TRANSFER_MIDDLE
|
state = async_inference_pb2.TRANSFER_MIDDLE
|
||||||
|
|
||||||
# Build timestep element in observation
|
|
||||||
# observation = np.hstack(
|
|
||||||
# (np.array([self.latest_action]), observation)
|
|
||||||
# ).astype(np.float32)
|
|
||||||
|
|
||||||
time.sleep(environment_dt) # Control the observation sending rate
|
|
||||||
self.send_observation(observation, state)
|
self.send_observation(observation, state)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -315,8 +300,21 @@ def async_client():
|
|||||||
if not hasattr(get_observation, "counter"):
|
if not hasattr(get_observation, "counter"):
|
||||||
get_observation.counter = 0
|
get_observation.counter = 0
|
||||||
|
|
||||||
# Create observation with incrementing first element
|
# Acquire lock before accessing the robot
|
||||||
observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
|
observation_content = None
|
||||||
|
if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
|
||||||
|
try:
|
||||||
|
observation_content = client.robot.capture_observation()
|
||||||
|
finally:
|
||||||
|
# Always release the lock in a finally block to ensure it's released
|
||||||
|
client.robot_lock.release()
|
||||||
|
else:
|
||||||
|
print("Could not acquire robot lock for observation capture, skipping this cycle")
|
||||||
|
return None # Return None to indicate no observation was captured
|
||||||
|
|
||||||
|
observation = TimedObservation(
|
||||||
|
timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
|
||||||
|
)
|
||||||
|
|
||||||
# Increment counter for next call
|
# Increment counter for next call
|
||||||
get_observation.counter += 1
|
get_observation.counter += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user