From 989f3d05ba47f872d75c587e76838e9cc574857a Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 23 Jul 2025 16:30:01 +0700 Subject: [PATCH] [Async Inference] Merge Protos & refactoring (#1480) * Merge together proto files and refactor Async inference * Fixup for Async inference * Drop not reuqired changes * Fix tests * Drop old async files * Drop chunk_size param * Fix versions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix wrong fix Co-authored-by: Ben Zhang * Fixup --------- Co-authored-by: Michel Aractingi Co-authored-by: Ben Zhang Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- pyproject.toml | 6 +- src/lerobot/scripts/server/helpers.py | 86 ------ src/lerobot/scripts/server/policy_server.py | 34 +-- src/lerobot/scripts/server/robot_client.py | 25 +- src/lerobot/transport/async_inference.proto | 59 ---- src/lerobot/transport/async_inference_pb2.py | 45 --- .../transport/async_inference_pb2_grpc.py | 277 ------------------ src/lerobot/transport/services.proto | 28 ++ src/lerobot/transport/services_pb2.py | 28 +- src/lerobot/transport/services_pb2_grpc.py | 211 ++++++++++++- src/lerobot/transport/utils.py | 10 +- tests/async_inference/test_e2e.py | 8 +- 12 files changed, 299 insertions(+), 518 deletions(-) delete mode 100644 src/lerobot/transport/async_inference.proto delete mode 100644 src/lerobot/transport/async_inference_pb2.py delete mode 100644 src/lerobot/transport/async_inference_pb2_grpc.py diff --git a/pyproject.toml b/pyproject.toml index ec259897..7a0ad148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency -grpcio-dep = ["grpcio==1.71.0"] +grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0"] @@ -119,14 +119,14 @@ intelrealsense = [ # Policies pi0 = ["lerobot[transformers-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] # Development docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/scripts/server/helpers.py index 7fd56e69..d8051b76 100644 --- a/src/lerobot/scripts/server/helpers.py +++ b/src/lerobot/scripts/server/helpers.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io import logging import logging.handlers import os import time from dataclasses import dataclass from pathlib import Path -from threading import Event -from typing import Any import torch @@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot -from lerobot.transport import async_inference_pb2 -from lerobot.transport.utils import bytes_buffer_size from lerobot.utils.utils import init_logging Action = torch.Tensor @@ -303,84 +298,3 @@ def observations_similar( ) return _compare_observation_states(obs1_state, obs2_state, atol=atol) - - -def send_bytes_in_chunks( - buffer: bytes, - message_class: Any, - log_prefix: str = "", - silent: bool = True, - chunk_size: int = 3 * 1024 * 1024, -): - # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we - # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the - # chunk size as I am using it to send image observations. - buffer = io.BytesIO(buffer) - size_in_bytes = bytes_buffer_size(buffer) - - sent_bytes = 0 - - logging_method = logging.info if not silent else logging.debug - - logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") - - while sent_bytes < size_in_bytes: - transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE - - if sent_bytes + chunk_size >= size_in_bytes: - transfer_state = async_inference_pb2.TransferState.TRANSFER_END - elif sent_bytes == 0: - transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN - - size_to_read = min(chunk_size, size_in_bytes - sent_bytes) - chunk = buffer.read(size_to_read) - - yield message_class(transfer_state=transfer_state, data=chunk) - sent_bytes += size_to_read - logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") - - logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") - - -def receive_bytes_in_chunks( - iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = "" -): # type: ignore - # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we - # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving - # is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown) - bytes_buffer = io.BytesIO() - step = 0 - - logger.info(f"{log_prefix} Starting receiver") - for item in iterator: - logger.debug(f"{log_prefix} Received item") - if not continue_receiving.is_set(): - logger.info(f"{log_prefix} Shutting down receiver") - return - - if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN: - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - bytes_buffer.write(item.data) - logger.debug(f"{log_prefix} Received data at step 0") - - elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE: - bytes_buffer.write(item.data) - step += 1 - logger.debug(f"{log_prefix} Received data at step {step}") - - elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END: - bytes_buffer.write(item.data) - logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") - - complete_bytes = bytes_buffer.getvalue() - - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - - logger.debug(f"{log_prefix} Queue updated") - return complete_bytes - - else: - logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") - raise ValueError(f"Received unknown transfer state {item.transfer_state}") diff --git a/src/lerobot/scripts/server/policy_server.py b/src/lerobot/scripts/server/policy_server.py index 13ba976e..0ed446d3 100644 --- a/src/lerobot/scripts/server/policy_server.py +++ b/src/lerobot/scripts/server/policy_server.py @@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import ( get_logger, observations_similar, raw_observation_to_observation, - receive_bytes_in_chunks, ) from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) +from lerobot.transport.utils import receive_bytes_in_chunks -class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): +class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): prefix = "policy_server" logger = get_logger(prefix) def __init__(self, config: PolicyServerConfig): self.config = config - self._running_event = threading.Event() + self.shutdown_event = threading.Event() # FPS measurement self.fps_tracker = FPSTracker(target_fps=config.fps) @@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): @property def running(self): - return self._running_event.is_set() + return not self.shutdown_event.is_set() @property def policy_image_features(self): @@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def _reset_server(self) -> None: """Flushes server state when new client connects.""" # only running inference on the latest observation received by the server - self._running_event.clear() + self.shutdown_event.set() self.observation_queue = Queue(maxsize=1) with self._predicted_timesteps_lock: @@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): client_id = context.peer() self.logger.info(f"Client {client_id} connected and ready") self._reset_server() - self._running_event.set() + self.shutdown_event.clear() - return async_inference_pb2.Empty() + return services_pb2.Empty() def SendPolicyInstructions(self, request, context): # noqa: N802 """Receive policy instructions from the robot client""" if not self.running: self.logger.warning("Server is not running. Ignoring policy instructions.") - return async_inference_pb2.Empty() + return services_pb2.Empty() client_id = context.peer() @@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") - return async_inference_pb2.Empty() + return services_pb2.Empty() def SendObservations(self, request_iterator, context): # noqa: N802 """Receive observations from the robot client""" @@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): receive_time = time.time() # comparing timestamps so need time.time() start_deserialize = time.perf_counter() received_bytes = receive_bytes_in_chunks( - request_iterator, self._running_event, self.logger + request_iterator, None, self.shutdown_event, self.logger ) # blocking call while looping over request_iterator timed_observation = pickle.loads(received_bytes) # nosec deserialize_time = time.perf_counter() - start_deserialize @@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): ): self.logger.info(f"Observation #{obs_timestep} has been filtered out") - return async_inference_pb2.Empty() + return services_pb2.Empty() def GetActions(self, request, context): # noqa: N802 """Returns actions to the robot client. Actions are sent as a single @@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): serialize_time = time.perf_counter() - start_time # Create and return the action chunk - actions = async_inference_pb2.Actions(data=actions_bytes) + actions = services_pb2.Actions(data=actions_bytes) self.logger.info( f"Action chunk #{obs.get_timestep()} generated | " @@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): return actions except Empty: # no observation added to queue in obs_queue_timeout - return async_inference_pb2.Empty() + return services_pb2.Empty() except Exception as e: self.logger.error(f"Error in StreamActions: {e}") - return async_inference_pb2.Empty() + return services_pb2.Empty() def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool: """Check if the observation is valid to be processed by the policy""" @@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig): # Setup and start gRPC server server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) - async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) server.add_insecure_port(f"{cfg.host}:{cfg.port}") policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}") diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index 68166de6..0599e068 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import ( TimedObservation, get_logger, map_robot_keys_to_lerobot_features, - send_bytes_in_chunks, validate_robot_cameras_for_policy, visualize_action_queue_size, ) from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) -from lerobot.transport.utils import grpc_channel_options +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks class RobotClient: @@ -118,10 +117,10 @@ class RobotClient: self.channel = grpc.insecure_channel( self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s") ) - self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel) self.logger.info(f"Initializing client to connect to server at {self.server_address}") - self._running_event = threading.Event() + self.shutdown_event = threading.Event() # Initialize client side variables self.latest_action_lock = threading.Lock() @@ -146,20 +145,20 @@ class RobotClient: @property def running(self): - return self._running_event.is_set() + return not self.shutdown_event.is_set() def start(self): """Start the robot client and connect to the policy server""" try: # client-server handshake start_time = time.perf_counter() - self.stub.Ready(async_inference_pb2.Empty()) + self.stub.Ready(services_pb2.Empty()) end_time = time.perf_counter() self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s") # send policy instructions policy_config_bytes = pickle.dumps(self.policy_config) - policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes) + policy_setup = services_pb2.PolicySetup(data=policy_config_bytes) self.logger.info("Sending policy instructions to policy server") self.logger.debug( @@ -170,7 +169,7 @@ class RobotClient: self.stub.SendPolicyInstructions(policy_setup) - self._running_event.set() + self.shutdown_event.clear() return True @@ -180,7 +179,7 @@ class RobotClient: def stop(self): """Stop the robot client""" - self._running_event.clear() + self.shutdown_event.set() self.robot.disconnect() self.logger.debug("Robot disconnected") @@ -208,7 +207,7 @@ class RobotClient: try: observation_iterator = send_bytes_in_chunks( observation_bytes, - async_inference_pb2.Observation, + services_pb2.Observation, log_prefix="[CLIENT] Observation", silent=True, ) @@ -283,7 +282,7 @@ class RobotClient: while self.running: try: # Use StreamActions to get a stream of actions from the server - actions_chunk = self.stub.GetActions(async_inference_pb2.Empty()) + actions_chunk = self.stub.GetActions(services_pb2.Empty()) if len(actions_chunk.data) == 0: continue # received `Empty` from server, wait for next call diff --git a/src/lerobot/transport/async_inference.proto b/src/lerobot/transport/async_inference.proto deleted file mode 100644 index 434f3142..00000000 --- a/src/lerobot/transport/async_inference.proto +++ /dev/null @@ -1,59 +0,0 @@ -// fmt: off -// flake8: noqa -// !/usr/bin/env python - -// Copyright 2024 The HuggingFace Inc. team. -// All rights reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -syntax = "proto3"; - -package async_inference; - -// AsyncInference: from Robot perspective -// Robot send observations to & executes action received from a remote Policy server -service AsyncInference { - // Robot -> Policy to share observations with a remote inference server - // Policy -> Robot to share actions predicted for given observations - rpc SendObservations(stream Observation) returns (Empty); - rpc GetActions(Empty) returns (Actions); - rpc SendPolicyInstructions(PolicySetup) returns (Empty); - rpc Ready(Empty) returns (Empty); - rpc Stop(Empty) returns (Empty); -} - -enum TransferState { - TRANSFER_UNKNOWN = 0; - TRANSFER_BEGIN = 1; - TRANSFER_MIDDLE = 2; - TRANSFER_END = 3; -} - -// Messages -message Observation { - // sent by Robot, to remote Policy - TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size - bytes data = 2; -} - -message Actions { - // sent by remote Policy, to Robot - bytes data = 1; -} - -message PolicySetup { - // sent by Robot to remote server, to init Policy - bytes data = 1; -} - -message Empty {} diff --git a/src/lerobot/transport/async_inference_pb2.py b/src/lerobot/transport/async_inference_pb2.py deleted file mode 100644 index 59c8eb48..00000000 --- a/src/lerobot/transport/async_inference_pb2.py +++ /dev/null @@ -1,45 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE -# source: async_inference.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'async_inference.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=190 - _globals['_TRANSFERSTATE']._serialized_end=286 - _globals['_OBSERVATION']._serialized_start=42 - _globals['_OBSERVATION']._serialized_end=125 - _globals['_ACTIONS']._serialized_start=127 - _globals['_ACTIONS']._serialized_end=150 - _globals['_POLICYSETUP']._serialized_start=152 - _globals['_POLICYSETUP']._serialized_end=179 - _globals['_EMPTY']._serialized_start=181 - _globals['_EMPTY']._serialized_end=188 - _globals['_ASYNCINFERENCE']._serialized_start=289 - _globals['_ASYNCINFERENCE']._serialized_end=638 -# @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/async_inference_pb2_grpc.py b/src/lerobot/transport/async_inference_pb2_grpc.py deleted file mode 100644 index 3042db0d..00000000 --- a/src/lerobot/transport/async_inference_pb2_grpc.py +++ /dev/null @@ -1,277 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from lerobot.transport import async_inference_pb2 as async__inference__pb2 - -GRPC_GENERATED_VERSION = '1.71.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in async_inference_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class AsyncInferenceStub: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.SendObservations = channel.stream_unary( - '/async_inference.AsyncInference/SendObservations', - request_serializer=async__inference__pb2.Observation.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.GetActions = channel.unary_unary( - '/async_inference.AsyncInference/GetActions', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Actions.FromString, - _registered_method=True) - self.SendPolicyInstructions = channel.unary_unary( - '/async_inference.AsyncInference/SendPolicyInstructions', - request_serializer=async__inference__pb2.PolicySetup.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.Ready = channel.unary_unary( - '/async_inference.AsyncInference/Ready', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.Stop = channel.unary_unary( - '/async_inference.AsyncInference/Stop', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - - -class AsyncInferenceServicer: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - def SendObservations(self, request_iterator, context): - """Robot -> Policy to share observations with a remote inference server - Policy -> Robot to share actions predicted for given observations - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetActions(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SendPolicyInstructions(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Ready(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Stop(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_AsyncInferenceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'SendObservations': grpc.stream_unary_rpc_method_handler( - servicer.SendObservations, - request_deserializer=async__inference__pb2.Observation.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'GetActions': grpc.unary_unary_rpc_method_handler( - servicer.GetActions, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Actions.SerializeToString, - ), - 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( - servicer.SendPolicyInstructions, - request_deserializer=async__inference__pb2.PolicySetup.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'Ready': grpc.unary_unary_rpc_method_handler( - servicer.Ready, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'Stop': grpc.unary_unary_rpc_method_handler( - servicer.Stop, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'async_inference.AsyncInference', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class AsyncInference: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - @staticmethod - def SendObservations(request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_unary( - request_iterator, - target, - '/async_inference.AsyncInference/SendObservations', - async__inference__pb2.Observation.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetActions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/GetActions', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Actions.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SendPolicyInstructions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/SendPolicyInstructions', - async__inference__pb2.PolicySetup.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Ready(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/Ready', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Stop(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/Stop', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/src/lerobot/transport/services.proto b/src/lerobot/transport/services.proto index 70f39741..ea0c12de 100644 --- a/src/lerobot/transport/services.proto +++ b/src/lerobot/transport/services.proto @@ -33,6 +33,17 @@ service LearnerService { rpc Ready(Empty) returns (Empty); } +// AsyncInference: from Robot perspective +// Robot send observations to & executes action received from a remote Policy server +service AsyncInference { + // Robot -> Policy to share observations with a remote inference server + // Policy -> Robot to share actions predicted for given observations + rpc SendObservations(stream Observation) returns (Empty); + rpc GetActions(Empty) returns (Actions); + rpc SendPolicyInstructions(PolicySetup) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + enum TransferState { TRANSFER_UNKNOWN = 0; TRANSFER_BEGIN = 1; @@ -56,4 +67,21 @@ message InteractionMessage { bytes data = 2; } +// Messages +message Observation { + // sent by Robot, to remote Policy + TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size + bytes data = 2; +} + +message Actions { + // sent by remote Policy, to Robot + bytes data = 1; +} + +message PolicySetup { + // sent by Robot to remote server, to init Policy + bytes data = 1; +} + message Empty {} diff --git a/src/lerobot/transport/services_pb2.py b/src/lerobot/transport/services_pb2.py index 9e66ae1e..05f2d174 100644 --- a/src/lerobot/transport/services_pb2.py +++ b/src/lerobot/transport/services_pb2.py @@ -1,7 +1,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: lerobot/transport/services.proto -# Protobuf Python Version: 5.29.0 +# Protobuf Python Version: 6.31.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, - 5, - 29, + 6, + 31, 0, '', 'lerobot/transport/services.proto' @@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=298 - _globals['_TRANSFERSTATE']._serialized_end=394 + _globals['_TRANSFERSTATE']._serialized_start=431 + _globals['_TRANSFERSTATE']._serialized_end=527 _globals['_TRANSITION']._serialized_start=47 _globals['_TRANSITION']._serialized_end=123 _globals['_PARAMETERS']._serialized_start=125 _globals['_PARAMETERS']._serialized_end=201 _globals['_INTERACTIONMESSAGE']._serialized_start=203 _globals['_INTERACTIONMESSAGE']._serialized_end=287 - _globals['_EMPTY']._serialized_start=289 - _globals['_EMPTY']._serialized_end=296 - _globals['_LEARNERSERVICE']._serialized_start=397 - _globals['_LEARNERSERVICE']._serialized_end=654 + _globals['_OBSERVATION']._serialized_start=289 + _globals['_OBSERVATION']._serialized_end=366 + _globals['_ACTIONS']._serialized_start=368 + _globals['_ACTIONS']._serialized_end=391 + _globals['_POLICYSETUP']._serialized_start=393 + _globals['_POLICYSETUP']._serialized_end=420 + _globals['_EMPTY']._serialized_start=422 + _globals['_EMPTY']._serialized_end=429 + _globals['_LEARNERSERVICE']._serialized_start=530 + _globals['_LEARNERSERVICE']._serialized_end=787 + _globals['_ASYNCINFERENCE']._serialized_start=790 + _globals['_ASYNCINFERENCE']._serialized_end=1035 # @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/services_pb2_grpc.py b/src/lerobot/transport/services_pb2_grpc.py index 77801a34..35a01b67 100644 --- a/src/lerobot/transport/services_pb2_grpc.py +++ b/src/lerobot/transport/services_pb2_grpc.py @@ -5,7 +5,7 @@ import warnings from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2 -GRPC_GENERATED_VERSION = '1.71.0' +GRPC_GENERATED_VERSION = '1.73.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -231,3 +231,212 @@ class LearnerService: timeout, metadata, _registered_method=True) + + +class AsyncInferenceStub: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendObservations = channel.stream_unary( + '/transport.AsyncInference/SendObservations', + request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.GetActions = channel.unary_unary( + '/transport.AsyncInference/GetActions', + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString, + _registered_method=True) + self.SendPolicyInstructions = channel.unary_unary( + '/transport.AsyncInference/SendPolicyInstructions', + request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/transport.AsyncInference/Ready', + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + + +class AsyncInferenceServicer: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def SendObservations(self, request_iterator, context): + """Robot -> Policy to share observations with a remote inference server + Policy -> Robot to share actions predicted for given observations + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetActions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendPolicyInstructions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsyncInferenceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendObservations': grpc.stream_unary_rpc_method_handler( + servicer.SendObservations, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'GetActions': grpc.unary_unary_rpc_method_handler( + servicer.GetActions, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString, + ), + 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( + servicer.SendPolicyInstructions, + request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'transport.AsyncInference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AsyncInference: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + @staticmethod + def SendObservations(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.AsyncInference/SendObservations', + lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetActions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/GetActions', + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Actions.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendPolicyInstructions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/SendPolicyInstructions', + lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/Ready', + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/lerobot/transport/utils.py b/src/lerobot/transport/utils.py index bf1aab75..5c9f702f 100644 --- a/src/lerobot/transport/utils.py +++ b/src/lerobot/transport/utils.py @@ -19,7 +19,8 @@ import io import json import logging import pickle # nosec B403: Safe usage for internal serialization only -from multiprocessing import Event, Queue +from multiprocessing import Event +from queue import Queue from typing import Any import torch @@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") -def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore +def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""): bytes_buffer = io.BytesIO() step = 0 @@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") - queue.put(bytes_buffer.getvalue()) + if queue is not None: + queue.put(bytes_buffer.getvalue()) + else: + return bytes_buffer.getvalue() bytes_buffer.seek(0) bytes_buffer.truncate(0) diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index d7b68e66..1c0400e6 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch): from lerobot.scripts.server.policy_server import PolicyServer from lerobot.scripts.server.robot_client import RobotClient from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) from tests.mocks.mock_robot import MockRobotConfig @@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch): # Bypass potentially heavy model loading inside SendPolicyInstructions def _fake_send_policy_instructions(self, request, context): # noqa: N802 - return async_inference_pb2.Empty() + return services_pb2.Empty() monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True) # Build gRPC server running a PolicyServer server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server")) - async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) # Use the host/port specified in the fixture's config server_address = f"{policy_server.config.host}:{policy_server.config.port}"