fix: action chunks are now predicted by the policy and timestamped from the observation they were computed from

This commit is contained in:
Francesco Capuano
2025-04-19 14:34:36 +02:00
parent b2d003e6eb
commit 2cce85b5dd

View File

@@ -1,38 +1,41 @@
import itertools import itertools
import pickle # nosec
import time import time
from concurrent import futures from concurrent import futures
from queue import Queue from queue import Queue
from typing import Generator, List from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore import async_inference_pb2_grpc # type: ignore
import grpc import grpc
import numpy as np
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.scripts.server.robot_client import TimedObservation from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt
inference_latency = 1 / 3 inference_latency = 1 / 3
idle_wait = 0.1 idle_wait = 0.1
def get_device() -> torch.device:
    """Return the preferred torch device for running inference.

    The preference order is CUDA, then Apple's MPS backend, then CPU.

    Returns:
        A ``torch.device`` of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    # Older torch builds do not ship `torch.backends.mps`; treat a missing
    # backend the same as an unavailable one instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def __init__(self, policy: PreTrainedPolicy = None): def __init__(self):
# TODO: Add code for loading and using policy for inference
self.policy = policy
# TODO: Add device specification for policy inference at init # TODO: Add device specification for policy inference at init
self.device = "mps"
start = time.time()
self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test")
self.policy.to(self.device)
end = time.time()
print(f"Time taken to put policy on {self.device}: {end - start} seconds")
# Initialize dataset action generator # Initialize dataset action generator
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset()) self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server() self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
def _setup_server(self) -> None: def _setup_server(self) -> None:
"""Flushes server state when new client connects.""" """Flushes server state when new client connects."""
# only running inference on the latest observation received by the server # only running inference on the latest observation received by the server
@@ -46,15 +49,11 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def SendObservations(self, request_iterator, context): # noqa: N802 def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client""" """Receive observations from the robot client"""
client_id = context.peer() # client_id = context.peer()
print(f"Receiving observations from {client_id}") # print(f"Receiving observations from {client_id}")
# print("Number of observations in queue: ", self.observation_queue.qsize())
for observation in request_iterator: for observation in request_iterator:
# Increment observation timestep counter for each new observation timed_observation = pickle.loads(observation.data) # nosec
observation_data = np.frombuffer(observation.data, dtype=np.float32)
observation_timestep = observation_data[0]
observation_content = observation_data[1:]
# If queue is full, get the old observation to make room # If queue is full, get the old observation to make room
if self.observation_queue.full(): if self.observation_queue.full():
@@ -62,14 +61,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
_ = self.observation_queue.get_nowait() _ = self.observation_queue.get_nowait()
# Now put the new observation (never blocks as queue is non-full here) # Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put( self.observation_queue.put(timed_observation)
TimedObservation( print("Received observation no: ", timed_observation.get_timestep())
timestep=int(observation_timestep),
observation=observation_content,
transfer_state=observation.transfer_state,
)
)
print("Received observation no: ", observation_timestep)
return async_inference_pb2.Empty() return async_inference_pb2.Empty()
@@ -91,15 +84,45 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
return async_inference_pb2.Empty() return async_inference_pb2.Empty()
def _predict_and_queue_action(self, observation): def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation""" """Predict an action based on the observation"""
# TODO: Implement the logic to predict an action based on the observation self.policy.eval()
"""
Ideally, action-prediction should be general and not specific to the policy used. observation = {}
That is, this interface should be the same for ACT/VLA/RL-based etc. for k, v in observation_t.get_observation().items():
""" if "image" in k:
# TODO: Queue the action to be sent to the robot client observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
raise NotImplementedError("Not implemented") else:
observation[k] = v.unsqueeze(0).to(self.device)
# Remove batch dimension
action_tensor = self.policy.select_action(observation).squeeze(0)
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
action_bytes = pickle.dumps(action_chunk) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast)
return action
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]: def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset. """Stream chunks of actions from a prerecorded dataset.
@@ -113,56 +136,44 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
actions = dataset["action"] actions = dataset["action"]
action_indices = torch.arange(len(actions)) action_indices = torch.arange(len(actions))
actions_per_chunk = 20
actions_overlap = 10
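# 2. Split the action indices into chunks of `actions_per_chunk` elements, consecutive chunks overlapping by `actions_overlap` elements # 2. Split the action indices into chunks of `actions_per_chunk` elements, consecutive chunks overlapping by `actions_overlap` elements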
# sending only first element for debugging # sending only first element for debugging
indices_chunks = action_indices.unfold(0, actions_per_chunk, actions_per_chunk - actions_overlap) indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks: for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :] yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
# Non overlapping action chunks def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
# actions_chunks = torch.split(actions, 20)
# for action_chunk in actions_chunks:
# yield action_chunk
def _predict_action_chunk(self, observation: TimedObservation):
"""Dummy function for predicting action chunk given observation. """Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset. actions from a prerecorded dataset.
""" """
transfer_state = 0 if not observation else observation.transfer_state import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
transfer_state = 0
else:
transfer_state = observation.transfer_state
# Get chunk of actions from the generator # Get chunk of actions from the generator
actions_chunk = next(self.action_generator) actions_chunk = next(self.action_generator)
# Convert the chunk of actions to a single contiguous numpy array # Return a list of TimedActions, with timestamps starting from the observation timestamp
# For the so100 dataset, each action in the chunk is a tensor with 6 elements action_data = self._time_action_chunk(
actions_array = actions_chunk.numpy() observation.get_timestamp(), actions_chunk, observation.get_timestep()
# Create timesteps starting from the observation timestep
# Each action in the chunk gets a timestep starting from observation_timestep
# This indicates that the first action corresponds to the current observation,
# and subsequent actions are for future timesteps (and predicted observations!)
timesteps = (
np.arange(observation.timestep, observation.timestep + len(actions_array))
.reshape(-1, 1)
.astype(np.float32)
) )
action_bytes = pickle.dumps(action_data) # nosec
# Create a combined array with timesteps and actions
# First column is the timestep, remaining columns are the action values
combined_array = np.hstack((timesteps, actions_array))
# Convert the numpy array to bytes for transmission
action_data = combined_array.astype(np.float32).tobytes()
# Create and return the Action message # Create and return the Action message
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_data) action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time time.sleep(inference_latency) # slow action generation, emulates inference time