fix: action chunks predicted using policy, and timed to observation used

This commit is contained in:
Francesco Capuano
2025-04-19 14:34:36 +02:00
parent b2d003e6eb
commit 2cce85b5dd

View File

@@ -1,38 +1,41 @@
import itertools
import pickle # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List
from typing import Generator, List, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import numpy as np
import torch
from datasets import load_dataset
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.scripts.server.robot_client import TimedObservation
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.scripts.server.robot_client import TimedAction, TimedObservation, environment_dt
inference_latency = 1 / 3
idle_wait = 0.1
def get_device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def __init__(self, policy: PreTrainedPolicy = None):
# TODO: Add code for loading and using policy for inference
self.policy = policy
def __init__(self):
# TODO: Add device specification for policy inference at init
self.device = "mps"
start = time.time()
self.policy = ACTPolicy.from_pretrained("fracapuano/act_so100_test")
self.policy.to(self.device)
end = time.time()
print(f"Time taken to put policy on {self.device}: {end - start} seconds")
# Initialize dataset action generator
self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
self._setup_server()
self.actions_per_chunk = 20
self.actions_overlap = 10
def _setup_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
@@ -46,15 +49,11 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
print(f"Receiving observations from {client_id}")
# print("Number of observations in queue: ", self.observation_queue.qsize())
# client_id = context.peer()
# print(f"Receiving observations from {client_id}")
for observation in request_iterator:
# Increment observation timestep counter for each new observation
observation_data = np.frombuffer(observation.data, dtype=np.float32)
observation_timestep = observation_data[0]
observation_content = observation_data[1:]
timed_observation = pickle.loads(observation.data) # nosec
# If queue is full, get the old observation to make room
if self.observation_queue.full():
@@ -62,14 +61,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
_ = self.observation_queue.get_nowait()
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(
TimedObservation(
timestep=int(observation_timestep),
observation=observation_content,
transfer_state=observation.transfer_state,
)
)
print("Received observation no: ", observation_timestep)
self.observation_queue.put(timed_observation)
print("Received observation no: ", timed_observation.get_timestep())
return async_inference_pb2.Empty()
@@ -91,15 +84,45 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
return async_inference_pb2.Empty()
def _predict_and_queue_action(self, observation):
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
]
@torch.no_grad()
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action based on the observation"""
# TODO: Implement the logic to predict an action based on the observation
"""
Ideally, action-prediction should be general and not specific to the policy used.
That is, this interface should be the same for ACT/VLA/RL-based etc.
"""
# TODO: Queue the action to be sent to the robot client
raise NotImplementedError("Not implemented")
self.policy.eval()
observation = {}
for k, v in observation_t.get_observation().items():
if "image" in k:
observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
else:
observation[k] = v.unsqueeze(0).to(self.device)
# Remove batch dimension
action_tensor = self.policy.select_action(observation).squeeze(0)
if action_tensor.dim() == 1:
# No chunk dimension, so repeat action to create a (dummy) chunk of actions
action_tensor = action_tensor.cpu().repeat(self.actions_per_chunk, 1)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
action_bytes = pickle.dumps(action_chunk) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time (ACT is very fast)
return action
def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
"""Stream chunks of actions from a prerecorded dataset.
@@ -113,56 +136,44 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
actions = dataset["action"]
action_indices = torch.arange(len(actions))
actions_per_chunk = 20
actions_overlap = 10
# 2. Chunk the iterable of tensors into chunks with 10 elements each
# sending only first element for debugging
indices_chunks = action_indices.unfold(0, actions_per_chunk, actions_per_chunk - actions_overlap)
indices_chunks = action_indices.unfold(
0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
)
for idx_chunk in indices_chunks:
yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
# Non overlapping action chunks
# actions_chunks = torch.split(actions, 20)
# for action_chunk in actions_chunks:
# yield action_chunk
def _predict_action_chunk(self, observation: TimedObservation):
def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
"""Dummy function for predicting action chunk given observation.
Instead of computing actions on-the-fly, this method streams
actions from a prerecorded dataset.
"""
transfer_state = 0 if not observation else observation.transfer_state
import warnings
warnings.warn(
"This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
)
if not observation:
observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
transfer_state = 0
else:
transfer_state = observation.transfer_state
# Get chunk of actions from the generator
actions_chunk = next(self.action_generator)
# Convert the chunk of actions to a single contiguous numpy array
# For the so100 dataset, each action in the chunk is a tensor with 6 elements
actions_array = actions_chunk.numpy()
# Create timesteps starting from the observation timestep
# Each action in the chunk gets a timestep starting from observation_timestep
# This indicates that the first action corresponds to the current observation,
# and subsequent actions are for future timesteps (and predicted observations!)
timesteps = (
np.arange(observation.timestep, observation.timestep + len(actions_array))
.reshape(-1, 1)
.astype(np.float32)
# Return a list of TimedActions, with timestamps starting from the observation timestamp
action_data = self._time_action_chunk(
observation.get_timestamp(), actions_chunk, observation.get_timestep()
)
# Create a combined array with timesteps and actions
# First column is the timestep, remaining columns are the action values
combined_array = np.hstack((timesteps, actions_array))
# Convert the numpy array to bytes for transmission
action_data = combined_array.astype(np.float32).tobytes()
action_bytes = pickle.dumps(action_data) # nosec
# Create and return the Action message
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_data)
action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
time.sleep(inference_latency) # slow action generation, emulates inference time