From 2cce85b5ddf14165cdcb1802da20505ba4836290 Mon Sep 17 00:00:00 2001 From: Francesco Capuano Date: Sat, 19 Apr 2025 14:34:36 +0200 Subject: [PATCH] fix: action chunks predicted using policy, and timed to observation used --- lerobot/scripts/server/policy_server.py | 145 +++++++++++++----------- 1 file changed, 78 insertions(+), 67 deletions(-) diff --git a/lerobot/scripts/server/policy_server.py b/lerobot/scripts/server/policy_server.py index 04b43b8c..47564ca8 100644 --- a/lerobot/scripts/server/policy_server.py +++ b/lerobot/scripts/server/policy_server.py @@ -1,38 +1,41 @@ import itertools +import pickle # nosec import time from concurrent import futures from queue import Queue -from typing import Generator, List +from typing import Generator, List, Optional import async_inference_pb2 # type: ignore import async_inference_pb2_grpc # type: ignore import grpc -import numpy as np import torch from datasets import load_dataset -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.scripts.server.robot_client import TimedObservation +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt inference_latency = 1 / 3 idle_wait = 0.1 -def get_device(): - return torch.device("cuda" if torch.cuda.is_available() else "cpu") - - class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): - def __init__(self, policy: PreTrainedPolicy = None): - # TODO: Add code for loading and using policy for inference - self.policy = policy - + def __init__(self): # TODO: Add device specification for policy inference at init + self.device = "mps" + start = time.time() + self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test") + self.policy.to(self.device) + end = time.time() + print(f"Time taken to put policy on {self.device}: {end - start} seconds") + # Initialize dataset action generator self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset()) self._setup_server() + self.actions_per_chunk = 20 + self.actions_overlap = 10 + def _setup_server(self) -> None: """Flushes server state when new client connects.""" # only running inference on the latest observation received by the server @@ -46,15 +49,11 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def SendObservations(self, request_iterator, context): # noqa: N802 """Receive observations from the robot client""" - client_id = context.peer() - print(f"Receiving observations from {client_id}") - # print("Number of observations in queue: ", self.observation_queue.qsize()) + # client_id = context.peer() + # print(f"Receiving observations from {client_id}") for observation in request_iterator: - # Increment observation timestep counter for each new observation - observation_data = np.frombuffer(observation.data, dtype=np.float32) - observation_timestep = observation_data[0] - observation_content = observation_data[1:] + timed_observation = pickle.loads(observation.data) # nosec # If queue is full, get the old observation to make room if self.observation_queue.full(): @@ -62,14 +61,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): _ = self.observation_queue.get_nowait() # Now put the new observation (never blocks as queue is non-full here) - self.observation_queue.put( - TimedObservation( - timestep=int(observation_timestep), - observation=observation_content, - transfer_state=observation.transfer_state, - ) - ) - print("Received observation no: ", observation_timestep) + self.observation_queue.put(timed_observation) + print("Received observation no: ", timed_observation.get_timestep()) return async_inference_pb2.Empty() @@ -91,15 +84,45 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): return async_inference_pb2.Empty() - def _predict_and_queue_action(self, observation): + def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]: + """Turn a chunk of actions into a list of TimedAction instances, + with the first action corresponding to t_0 and the rest corresponding to + t_0 + i*environment_dt for i in range(len(action_chunk)) + """ + return [ + TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk) + ] + + @torch.no_grad() + def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: """Predict an action based on the observation""" - # TODO: Implement the logic to predict an action based on the observation - """ - Ideally, action-prediction should be general and not specific to the policy used. - That is, this interface should be the same for ACT/VLA/RL-based etc. - """ - # TODO: Queue the action to be sent to the robot client - raise NotImplementedError("Not implemented") + self.policy.eval() + + observation = {} + for k, v in observation_t.get_observation().items(): + if "image" in k: + observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device) + else: + observation[k] = v.unsqueeze(0).to(self.device) + + # Remove batch dimension + action_tensor = self.policy.select_action(observation).squeeze(0) + + if action_tensor.dim() == 1: + # No chunk dimension, so repeat action to create a (dummy) chunk of actions + action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1) + + action_chunk = self._time_action_chunk( + observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() + ) + + action_bytes = pickle.dumps(action_chunk) # nosec + # Create and return the Action message + action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes) + + time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast) + + return action def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]: """Stream chunks of actions from a prerecorded dataset. @@ -113,56 +136,44 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): actions = dataset["action"] action_indices = torch.arange(len(actions)) - actions_per_chunk = 20 - actions_overlap = 10 - # 2. Chunk the iterable of tensors into chunks with 10 elements each # sending only first element for debugging - indices_chunks = action_indices.unfold(0, actions_per_chunk, actions_per_chunk - actions_overlap) + indices_chunks = action_indices.unfold( + 0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap + ) for idx_chunk in indices_chunks: yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :] - # Non overlapping action chunks - # actions_chunks = torch.split(actions, 20) - # for action_chunk in actions_chunks: - # yield action_chunk - - def _predict_action_chunk(self, observation: TimedObservation): + def _read_action_chunk(self, observation: Optional[TimedObservation] = None): """Dummy function for predicting action chunk given observation. Instead of computing actions on-the-fly, this method streams actions from a prerecorded dataset. """ - transfer_state = 0 if not observation else observation.transfer_state + import warnings + + warnings.warn( + "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2 + ) + + if not observation: + observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0) + transfer_state = 0 + else: + transfer_state = observation.transfer_state # Get chunk of actions from the generator actions_chunk = next(self.action_generator) - # Convert the chunk of actions to a single contiguous numpy array - # For the so100 dataset, each action in the chunk is a tensor with 6 elements - actions_array = actions_chunk.numpy() - - # Create timesteps starting from the observation timestep - # Each action in the chunk gets a timestep starting from observation_timestep - # This indicates that the first action corresponds to the current observation, - # and subsequent actions are for future timesteps (and predicted observations!) - - timesteps = ( - np.arange(observation.timestep, observation.timestep + len(actions_array)) - .reshape(-1, 1) - .astype(np.float32) + # Return a list of TimedActions, with timestamps starting from the observation timestamp + action_data = self._time_action_chunk( + observation.get_timestamp(), actions_chunk, observation.get_timestep() ) - - # Create a combined array with timesteps and actions - # First column is the timestep, remaining columns are the action values - combined_array = np.hstack((timesteps, actions_array)) - - # Convert the numpy array to bytes for transmission - action_data = combined_array.astype(np.float32).tobytes() + action_bytes = pickle.dumps(action_data) # nosec # Create and return the Action message - action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_data) + action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes) time.sleep(inference_latency) # slow action generation, emulates inference time