fix: client sends timed objects only, and uses lock to read & write robot status
@@ -1,12 +1,12 @@
 import pickle # nosec
 import threading
 import time
 from queue import Empty, Queue
-from typing import Optional
+from typing import Any, Optional

 import async_inference_pb2 # type: ignore
 import async_inference_pb2_grpc # type: ignore
 import grpc
 import numpy as np
 import torch

 from lerobot.common.robot_devices.robots.utils import make_robot
@@ -16,28 +16,40 @@ idle_wait = 0.1


 class TimedData:
-    def __init__(self, timestep: int, data: np.ndarray):
-        self.timestep = timestep
+    def __init__(self, timestamp: float, data: Any, timestep: int):
+        """Initialize a TimedData object.
+
+        Args:
+            timestamp: Unix timestamp relative to data's creation.
+            data: The actual data to wrap a timestamp around.
+        """
+        self.timestamp = timestamp
         self.data = data
+        self.timestep = timestep

     def get_data(self):
         return self.data

+    def get_timestamp(self):
+        return self.timestamp
+
     def get_timestep(self):
         return self.timestep


 class TimedAction(TimedData):
-    def __init__(self, timestep: int, action: np.ndarray):
-        super().__init__(timestep, action)
+    def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
+        super().__init__(timestamp=timestamp, data=action, timestep=timestep)

     def get_action(self):
         return self.get_data()


 class TimedObservation(TimedData):
-    def __init__(self, timestep: int, observation: np.ndarray, transfer_state: int = 0):
-        super().__init__(timestep, observation)
+    def __init__(
+        self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
+    ):
+        super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
         self.transfer_state = transfer_state

     def get_observation(self):
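For orientation, here is a minimal standalone sketch of the round trip these classes enable: the client now pickles a whole TimedObservation and ships the bytes, instead of flattening arrays into raw buffers. The classes below are slim stand-ins mirroring the diff (not imported from the repo), and the observation key and tensor shape are invented for the example.

import pickle # nosec - mirrors the client; only exchange pickled data between trusted processes
import time
from typing import Any

import torch


class TimedData:
    # Stand-in mirroring the class introduced in this commit.
    def __init__(self, timestamp: float, data: Any, timestep: int):
        self.timestamp = timestamp
        self.data = data
        self.timestep = timestep


class TimedObservation(TimedData):
    def __init__(self, timestamp: float, observation: dict, timestep: int, transfer_state: int = 0):
        super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
        self.transfer_state = transfer_state


# Wrap a (hypothetical) observation dict, serialize it as the client now does,
# and recover it the way the receiving side would.
obs = TimedObservation(
    timestamp=time.time(),
    observation={"observation.state": torch.zeros(6)},  # key and shape are made up for the sketch
    timestep=0,
)
payload = pickle.dumps(obs)  # these bytes are what gets placed in the gRPC message
restored = pickle.loads(payload)  # nosec
assert restored.timestep == 0
assert torch.equal(restored.data["observation.state"], obs.data["observation.state"])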
@@ -60,9 +72,10 @@ class RobotClient:
         self.action_chunk_size = 20

         self.action_queue = Queue()
-        self.start_barrier = threading.Barrier(3) # Barrier for 3 threads
+        self.start_barrier = threading.Barrier(3)

-        self.observation_timestep = 0
+        # Create a lock for robot access
+        self.robot_lock = threading.Lock()

         self.use_robot = use_robot
         if self.use_robot:
@@ -72,8 +85,10 @@ class RobotClient:
             time.sleep(idle_wait) # sleep waiting for cameras to activate
             print("Robot connected")

-    def timesteps(self):
-        """Get the timesteps of the actions in the queue"""
+            self.robot_reading = True
+
+    def timestamps(self):
+        """Get the timestamps of the actions in the queue"""
         return sorted([action.get_timestep() for action in self.action_queue.queue])

     def start(self):
@@ -93,11 +108,13 @@ class RobotClient:
     def stop(self):
         """Stop the robot client"""
         self.running = False
         if self.use_robot and hasattr(self, "robot"):
             self.robot.disconnect()
         self.channel.close()

     def send_observation(
         self,
-        observation_data: np.ndarray,
+        obs: TimedObservation,
         transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
     ) -> bool:
         """Send observation to the policy server.
@@ -106,10 +123,10 @@ class RobotClient:
             print("Client not running")
             return False

-        # Convert observation data to bytes
-        observation_data = observation_data.tobytes()
+        assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"

-        observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_data)
+        observation_bytes = pickle.dumps(obs)
+        observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)

         try:
             _ = self.stub.SendObservations(iter([observation]))
@@ -121,24 +138,19 @@ class RobotClient:
             print(f"Error sending observation: {e}")
             return False

-    def _validate_action(self, action: np.ndarray):
-        """Validate the action"""
-        assert action.shape == (7,), f"Action shape must be (7,) (including timestep), got {action.shape}"
+    def _validate_action(self, action: TimedAction):
+        """Received actions are kept only when they have been produced for now or later, never before"""
+        return not action.get_timestamp() < self.latest_action

-        return True
-
-    def _validate_action_chunk(self, actions: list[np.ndarray]):
-        """Validate the action chunk"""
+    def _validate_action_chunk(self, actions: list[TimedAction]):
         assert len(actions) == self.action_chunk_size, (
             f"Action batch size must match action chunk size: {len(actions)} != {self.action_chunk_size}"
         )

         assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"

         return True

     def _inspect_action_queue(self):
         """Inspect the action queue"""
         print("Queue size: ", self.action_queue.qsize())
         print("Queue contents: ", sorted([action.get_timestep() for action in self.action_queue.queue]))
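The rule above keeps an action only if its timestamp is not older than the last action already executed. A toy illustration of that filter, using plain float timestamps instead of TimedAction objects (all values invented):

from queue import Queue

latest_action = 3.0  # timestamp of the action most recently sent to the robot
incoming = [1.0, 2.0, 3.0, 4.0, 5.0]  # timestamps of a freshly received chunk

queue: Queue = Queue()
for ts in incoming:
    # Same comparison as _validate_action: keep "now or later", drop anything older.
    if not ts < latest_action:
        queue.put(ts)

print(list(queue.queue))  # [3.0, 4.0, 5.0] - the two stale actions are dropped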
@@ -151,13 +163,10 @@ class RobotClient:
                 break

     def _fill_action_queue(self, actions: list[TimedAction]):
-        """Fill the action queue with incoming actions"""
+        """Fill the action queue with incoming valid actions"""
         for action in actions:
-            # Only keep the actions that are newer than the latest action
-            if action.get_timestep() <= self.latest_action:
-                continue
-
-            self.action_queue.put(action)
+            if self._validate_action(action):
+                self.action_queue.put(action)

     def _update_action_queue(self, actions: list[TimedAction]):
         """Aggregate incoming actions into the action queue.
@@ -176,6 +185,7 @@ class RobotClient:
         Args:
             actions: List of TimedAction instances to queue
         """
+        print("*** Current latest action: ", self.latest_action, "***")
         print("\t**** Current queue content ****: ")
         self._inspect_action_queue()
@@ -188,55 +198,22 @@ class RobotClient:
         print("\t*** Queue after clearing and filling ****: ")
         self._inspect_action_queue()

-    def _create_timed_actions(self, action_data: np.ndarray) -> list[TimedAction]:
-        """Create TimedAction instances from raw action data.
-
-        Args:
-            action_data: Numpy array of shape (chunk_size, 7) where first column
-                is timestep and remaining columns are action values.
-
-        Returns:
-            List of TimedAction instances
-        """
-        timed_actions = []
-        for action in action_data:
-            timestep = int(action[0]) # First element is the timestep
-            action_values = action[1:] # Remaining elements are the action
-            timed_actions.append(TimedAction(timestep, action_values))
-
-        return timed_actions
-
     def receive_actions(self):
         """Receive actions from the policy server"""
         # Wait at barrier for synchronized start
         self.start_barrier.wait()
         print("Action receiving thread starting")

-        print(self.timesteps())
-
         while self.running:
             try:
                 # Use StreamActions to get a stream of actions from the server
-                action_chunks_counter = 0
-                for action in self.stub.StreamActions(async_inference_pb2.Empty()):
-                    # Read the action data which includes timesteps
-                    # Shape is (chunk_size, 7) where first column is timestep
-                    action_data = np.frombuffer(action.data, dtype=np.float32).reshape(
-                        self.action_chunk_size, 7
-                    )
-
-                    print("*** Receiving actions ****: ")
-                    # Convert raw action data to TimedAction instances
-                    timed_actions = self._create_timed_actions(action_data)
+                for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
+                    # Deserialize bytes back into list[TimedAction]
+                    timed_actions = pickle.loads(actions_chunk.data) # nosec

                     # strategy for queue composition is specified in the method
                     self._clear_and_fill_action_queue(timed_actions)

-                    action_chunks_counter += 1
-
-                    if action_chunks_counter > 2:
-                        raise ValueError("Too many action chunks received")
-
             except grpc.RpcError as e:
                 print(f"Error receiving actions: {e}")
                 time.sleep(idle_wait) # Avoid tight loop on error
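The receive loop now simply unpickles each streamed chunk and hands it to _clear_and_fill_action_queue. That method's exact strategy lives elsewhere in the file; the sketch below only shows the general clear-then-refill shape, with (timestamp, name) tuples standing in for TimedAction objects and all values invented.

import pickle # nosec
from queue import Empty, Queue

# Pretend these bytes arrived in one StreamActions message.
chunk_bytes = pickle.dumps([(3.0, "a3"), (4.0, "a4"), (5.0, "a5")])

action_queue: Queue = Queue()
action_queue.put((2.0, "stale"))  # something left over from the previous chunk

incoming = pickle.loads(chunk_bytes)  # nosec

# Drop whatever is still queued, then refill with the freshly received chunk.
while True:
    try:
        action_queue.get_nowait()
    except Empty:
        break
for action in incoming:
    action_queue.put(action)

print(list(action_queue.queue))  # [(3.0, 'a3'), (4.0, 'a4'), (5.0, 'a5')]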
@@ -258,16 +235,24 @@ class RobotClient:

         while self.running:
             # Get the next action from the queue
+            time.sleep(environment_dt)
             timed_action = self._get_next_action()

             if timed_action is not None:
-                self.latest_action = timed_action.get_timestep()
+                # self.latest_action = timed_action.get_timestep()
+                self.latest_action = timed_action.get_timestamp()

                 # Convert action to tensor and send to robot
                 if self.use_robot:
-                    self.robot.send_action(torch.tensor(timed_action.get_action()))
-
-                time.sleep(environment_dt)
+                    # Acquire lock before accessing the robot
+                    if self.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
+                        try:
+                            self.robot.send_action(timed_action.get_action())
+                        finally:
+                            # Always release the lock in a finally block to ensure it's released
+                            self.robot_lock.release()
+                    else:
+                        print("Could not acquire robot lock for action execution, retrying next cycle")

             else:
                 # No action available, wait and retry fetching from queue
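Both threads that touch the robot (this action-execution loop and the observation capture further down) now go through robot_lock with the same acquire-with-timeout / try-finally shape, so reads and writes of robot state cannot interleave. A self-contained sketch of that pattern, using a dummy resource instead of the real robot object:

import threading
import time

robot_lock = threading.Lock()


def use_robot(role: str) -> None:
    # Wait up to 1 second for the lock; skip this cycle if another thread holds it.
    if robot_lock.acquire(timeout=1.0):
        try:
            print(f"{role}: accessing the (dummy) robot")
            time.sleep(0.2)  # stand-in for robot.send_action(...) or robot.capture_observation()
        finally:
            # Always release in a finally block so an exception cannot leave the lock held.
            robot_lock.release()
    else:
        print(f"{role}: could not acquire robot lock, skipping this cycle")


threads = [
    threading.Thread(target=use_robot, args=("action-execution",)),
    threading.Thread(target=use_robot, args=("observation-capture",)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()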
@@ -282,8 +267,14 @@ class RobotClient:
         first_observation = True
         while self.running:
             try:
+                # Get serialized observation bytes from the function
+                time.sleep(environment_dt)
                 observation = get_observation_fn()

+                # Skip if observation is None (couldn't acquire lock)
+                if observation is None:
+                    continue
+
                 # Set appropriate transfer state
                 if first_observation:
                     state = async_inference_pb2.TRANSFER_BEGIN
@@ -291,12 +282,6 @@ class RobotClient:
                 else:
                     state = async_inference_pb2.TRANSFER_MIDDLE

-                # Build timestep element in observation
-                # observation = np.hstack(
-                #     (np.array([self.latest_action]), observation)
-                # ).astype(np.float32)
-
-                time.sleep(environment_dt) # Control the observation sending rate
                 self.send_observation(observation, state)

             except Exception as e:
@@ -315,8 +300,21 @@ def async_client():
        if not hasattr(get_observation, "counter"):
            get_observation.counter = 0

-        # Create observation with incrementing first element
-        observation = np.array([get_observation.counter, 0, 0], dtype=np.float32)
+        # Acquire lock before accessing the robot
+        observation_content = None
+        if client.robot_lock.acquire(timeout=1.0): # Wait up to 1 second to acquire the lock
+            try:
+                observation_content = client.robot.capture_observation()
+            finally:
+                # Always release the lock in a finally block to ensure it's released
+                client.robot_lock.release()
+        else:
+            print("Could not acquire robot lock for observation capture, skipping this cycle")
+            return None # Return None to indicate no observation was captured
+
+        observation = TimedObservation(
+            timestamp=time.time(), observation=observation_content, timestep=get_observation.counter
+        )

        # Increment counter for next call
        get_observation.counter += 1