[Async Inference] Merge Protos & refactoring (#1480)

* Merge together proto files and refactor Async inference

* Fixup for Async inference

* Drop not required changes

* Fix tests

* Drop old async files

* Drop chunk_size param

* Fix versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix wrong fix

Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>

* Fixup

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Eugene Mironov
2025-07-23 16:30:01 +07:00
committed by GitHub
parent f5d6b5b3a7
commit 989f3d05ba
12 changed files with 299 additions and 518 deletions

View File

@@ -95,7 +95,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bump dependency
grpcio-dep = ["grpcio==1.71.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0"]
@@ -119,14 +119,14 @@ intelrealsense = [
# Policies
pi0 = ["lerobot[transformers-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
# Development
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]

View File

@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Event
from typing import Any
import torch
@@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging
Action = torch.Tensor
@@ -303,84 +298,3 @@ def observations_similar(
)
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
def send_bytes_in_chunks(
buffer: bytes,
message_class: Any,
log_prefix: str = "",
silent: bool = True,
chunk_size: int = 3 * 1024 * 1024,
):
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
# chunk size as I am using it to send image observations.
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + chunk_size >= size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
): # type: ignore
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
# is opposite to the HIL-SERL design (my event showcases keeping on running instead of shutdown)
bytes_buffer = io.BytesIO()
step = 0
logger.info(f"{log_prefix} Starting receiver")
for item in iterator:
logger.debug(f"{log_prefix} Received item")
if not continue_receiving.is_set():
logger.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step 0")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logger.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
complete_bytes = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
logger.debug(f"{log_prefix} Queue updated")
return complete_bytes
else:
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
raise ValueError(f"Received unknown transfer state {item.transfer_state}")

View File

@@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import (
get_logger,
observations_similar,
raw_observation_to_observation,
receive_bytes_in_chunks,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig):
self.config = config
self._running_event = threading.Event()
self.shutdown_event = threading.Event()
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps)
@@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
@property
def running(self):
return self._running_event.is_set()
return not self.shutdown_event.is_set()
@property
def policy_image_features(self):
@@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def _reset_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self._running_event.clear()
self.shutdown_event.set()
self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock:
@@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._reset_server()
self._running_event.set()
self.shutdown_event.clear()
return async_inference_pb2.Empty()
return services_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty()
return services_pb2.Empty()
client_id = context.peer()
@@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
@@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger
request_iterator, None, self.shutdown_event, self.logger
) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize
@@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
):
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single
@@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
serialize_time = time.perf_counter() - start_time
# Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes)
actions = services_pb2.Actions(data=actions_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | "
@@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
return actions
except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty()
return services_pb2.Empty()
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
return services_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy"""
@@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig):
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")

View File

@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
class RobotClient:
@@ -118,10 +117,10 @@ class RobotClient:
self.channel = grpc.insecure_channel(
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event()
self.shutdown_event = threading.Event()
# Initialize client side variables
self.latest_action_lock = threading.Lock()
@@ -146,20 +145,20 @@ class RobotClient:
@property
def running(self):
return self._running_event.is_set()
return not self.shutdown_event.is_set()
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty())
self.stub.Ready(services_pb2.Empty())
end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server")
self.logger.debug(
@@ -170,7 +169,7 @@ class RobotClient:
self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set()
self.shutdown_event.clear()
return True
@@ -180,7 +179,7 @@ class RobotClient:
def stop(self):
"""Stop the robot client"""
self._running_event.clear()
self.shutdown_event.set()
self.robot.disconnect()
self.logger.debug("Robot disconnected")
@@ -208,7 +207,7 @@ class RobotClient:
try:
observation_iterator = send_bytes_in_chunks(
observation_bytes,
async_inference_pb2.Observation,
services_pb2.Observation,
log_prefix="[CLIENT] Observation",
silent=True,
)
@@ -283,7 +282,7 @@ class RobotClient:
while self.running:
try:
# Use StreamActions to get a stream of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
actions_chunk = self.stub.GetActions(services_pb2.Empty())
if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call

View File

@@ -1,59 +0,0 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
rpc Stop(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}

View File

@@ -1,45 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=190
_globals['_TRANSFERSTATE']._serialized_end=286
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTIONS']._serialized_start=127
_globals['_ACTIONS']._serialized_end=150
_globals['_POLICYSETUP']._serialized_start=152
_globals['_POLICYSETUP']._serialized_end=179
_globals['_EMPTY']._serialized_start=181
_globals['_EMPTY']._serialized_end=188
_globals['_ASYNCINFERENCE']._serialized_start=289
_globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)

View File

@@ -1,277 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/async_inference.AsyncInference/GetActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Stop = channel.unary_unary(
'/async_inference.AsyncInference/Stop',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Stop(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Stop': grpc.unary_unary_rpc_method_handler(
servicer.Stop,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/GetActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Stop(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Stop',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -33,6 +33,17 @@ service LearnerService {
rpc Ready(Empty) returns (Empty);
}
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
@@ -56,4 +67,21 @@ message InteractionMessage {
bytes data = 2;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}

View File

@@ -1,7 +1,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: lerobot/transport/services.proto
# Protobuf Python Version: 5.29.0
# Protobuf Python Version: 6.31.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
6,
31,
0,
'',
'lerobot/transport/services.proto'
@@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=298
_globals['_TRANSFERSTATE']._serialized_end=394
_globals['_TRANSFERSTATE']._serialized_start=431
_globals['_TRANSFERSTATE']._serialized_end=527
_globals['_TRANSITION']._serialized_start=47
_globals['_TRANSITION']._serialized_end=123
_globals['_PARAMETERS']._serialized_start=125
_globals['_PARAMETERS']._serialized_end=201
_globals['_INTERACTIONMESSAGE']._serialized_start=203
_globals['_INTERACTIONMESSAGE']._serialized_end=287
_globals['_EMPTY']._serialized_start=289
_globals['_EMPTY']._serialized_end=296
_globals['_LEARNERSERVICE']._serialized_start=397
_globals['_LEARNERSERVICE']._serialized_end=654
_globals['_OBSERVATION']._serialized_start=289
_globals['_OBSERVATION']._serialized_end=366
_globals['_ACTIONS']._serialized_start=368
_globals['_ACTIONS']._serialized_end=391
_globals['_POLICYSETUP']._serialized_start=393
_globals['_POLICYSETUP']._serialized_end=420
_globals['_EMPTY']._serialized_start=422
_globals['_EMPTY']._serialized_end=429
_globals['_LEARNERSERVICE']._serialized_start=530
_globals['_LEARNERSERVICE']._serialized_end=787
_globals['_ASYNCINFERENCE']._serialized_start=790
_globals['_ASYNCINFERENCE']._serialized_end=1035
# @@protoc_insertion_point(module_scope)

View File

@@ -5,7 +5,7 @@ import warnings
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_GENERATED_VERSION = '1.73.1'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
@@ -231,3 +231,212 @@ class LearnerService:
timeout,
metadata,
_registered_method=True)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/transport.AsyncInference/SendObservations',
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/transport.AsyncInference/GetActions',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/transport.AsyncInference/SendPolicyInstructions',
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/transport.AsyncInference/Ready',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'transport.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/transport.AsyncInference/SendObservations',
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/GetActions',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/SendPolicyInstructions',
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/Ready',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -19,7 +19,8 @@ import io
import json
import logging
import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue
from multiprocessing import Event
from queue import Queue
from typing import Any
import torch
@@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
bytes_buffer = io.BytesIO()
step = 0
@@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
queue.put(bytes_buffer.getvalue())
if queue is not None:
queue.put(bytes_buffer.getvalue())
else:
return bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)

View File

@@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch):
from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from tests.mocks.mock_robot import MockRobotConfig
@@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch):
# Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802
return async_inference_pb2.Empty()
return services_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the fixture's config
server_address = f"{policy_server.config.host}:{policy_server.config.port}"