[Async Inference] Merge Protos & refactoring (#1480)

* Merge together proto files and refactor Async inference

* Fixup for Async inference

* Drop not reuqired changes

* Fix tests

* Drop old async files

* Drop chunk_size param

* Fix versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix wrong fix

Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>

* Fixup

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Eugene Mironov
2025-07-23 16:30:01 +07:00
committed by GitHub
parent f5d6b5b3a7
commit 989f3d05ba
12 changed files with 299 additions and 518 deletions

View File

@@ -95,7 +95,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1"] pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"] placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
grpcio-dep = ["grpcio==1.71.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors # Motors
feetech = ["feetech-servo-sdk>=1.0.0"] feetech = ["feetech-servo-sdk>=1.0.0"]
@@ -119,14 +119,14 @@ intelrealsense = [
# Policies # Policies
pi0 = ["lerobot[transformers-dep]"] pi0 = ["lerobot[transformers-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features # Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
# Development # Development
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]

View File

@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import io
import logging import logging
import logging.handlers import logging.handlers
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from threading import Event
from typing import Any
import torch import torch
@@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
Action = torch.Tensor Action = torch.Tensor
@@ -303,84 +298,3 @@ def observations_similar(
) )
return _compare_observation_states(obs1_state, obs2_state, atol=atol) return _compare_observation_states(obs1_state, obs2_state, atol=atol)
def send_bytes_in_chunks(
buffer: bytes,
message_class: Any,
log_prefix: str = "",
silent: bool = True,
chunk_size: int = 3 * 1024 * 1024,
):
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
# chunk size as I am using it to send image observations.
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + chunk_size >= size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
): # type: ignore
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
# is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
bytes_buffer = io.BytesIO()
step = 0
logger.info(f"{log_prefix} Starting receiver")
for item in iterator:
logger.debug(f"{log_prefix} Received item")
if not continue_receiving.is_set():
logger.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step 0")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logger.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
complete_bytes = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
logger.debug(f"{log_prefix} Queue updated")
return complete_bytes
else:
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
raise ValueError(f"Received unknown transfer state {item.transfer_state}")

View File

@@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import (
get_logger, get_logger,
observations_similar, observations_similar,
raw_observation_to_observation, raw_observation_to_observation,
receive_bytes_in_chunks,
) )
from lerobot.transport import ( from lerobot.transport import (
async_inference_pb2, # type: ignore services_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore services_pb2_grpc, # type: ignore
) )
from lerobot.transport.utils import receive_bytes_in_chunks
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server" prefix = "policy_server"
logger = get_logger(prefix) logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig): def __init__(self, config: PolicyServerConfig):
self.config = config self.config = config
self._running_event = threading.Event() self.shutdown_event = threading.Event()
# FPS measurement # FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps) self.fps_tracker = FPSTracker(target_fps=config.fps)
@@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
@property @property
def running(self): def running(self):
return self._running_event.is_set() return not self.shutdown_event.is_set()
@property @property
def policy_image_features(self): def policy_image_features(self):
@@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def _reset_server(self) -> None: def _reset_server(self) -> None:
"""Flushes server state when new client connects.""" """Flushes server state when new client connects."""
# only running inference on the latest observation received by the server # only running inference on the latest observation received by the server
self._running_event.clear() self.shutdown_event.set()
self.observation_queue = Queue(maxsize=1) self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock: with self._predicted_timesteps_lock:
@@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
client_id = context.peer() client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready") self.logger.info(f"Client {client_id} connected and ready")
self._reset_server() self._reset_server()
self._running_event.set() self.shutdown_event.clear()
return async_inference_pb2.Empty() return services_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802 def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client""" """Receive policy instructions from the robot client"""
if not self.running: if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.") self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty() return services_pb2.Empty()
client_id = context.peer() client_id = context.peer()
@@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty() return services_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802 def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client""" """Receive observations from the robot client"""
@@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
receive_time = time.time() # comparing timestamps so need time.time() receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter() start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks( received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger request_iterator, None, self.shutdown_event, self.logger
) # blocking call while looping over request_iterator ) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize deserialize_time = time.perf_counter() - start_deserialize
@@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
): ):
self.logger.info(f"Observation #{obs_timestep} has been filtered out") self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty() return services_pb2.Empty()
def GetActions(self, request, context): # noqa: N802 def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single """Returns actions to the robot client. Actions are sent as a single
@@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
serialize_time = time.perf_counter() - start_time serialize_time = time.perf_counter() - start_time
# Create and return the action chunk # Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes) actions = services_pb2.Actions(data=actions_bytes)
self.logger.info( self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | " f"Action chunk #{obs.get_timestep()} generated | "
@@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
return actions return actions
except Empty: # no observation added to queue in obs_queue_timeout except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty() return services_pb2.Empty()
except Exception as e: except Exception as e:
self.logger.error(f"Error in StreamActions: {e}") self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty() return services_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool: def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy""" """Check if the observation is valid to be processed by the policy"""
@@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig):
# Setup and start gRPC server # Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"{cfg.host}:{cfg.port}") server.add_insecure_port(f"{cfg.host}:{cfg.port}")
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}") policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")

View File

@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
TimedObservation, TimedObservation,
get_logger, get_logger,
map_robot_keys_to_lerobot_features, map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy, validate_robot_cameras_for_policy,
visualize_action_queue_size, visualize_action_queue_size,
) )
from lerobot.transport import ( from lerobot.transport import (
async_inference_pb2, # type: ignore services_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore services_pb2_grpc, # type: ignore
) )
from lerobot.transport.utils import grpc_channel_options from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
class RobotClient: class RobotClient:
@@ -118,10 +117,10 @@ class RobotClient:
self.channel = grpc.insecure_channel( self.channel = grpc.insecure_channel(
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s") self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
) )
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}") self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event() self.shutdown_event = threading.Event()
# Initialize client side variables # Initialize client side variables
self.latest_action_lock = threading.Lock() self.latest_action_lock = threading.Lock()
@@ -146,20 +145,20 @@ class RobotClient:
@property @property
def running(self): def running(self):
return self._running_event.is_set() return not self.shutdown_event.is_set()
def start(self): def start(self):
"""Start the robot client and connect to the policy server""" """Start the robot client and connect to the policy server"""
try: try:
# client-server handshake # client-server handshake
start_time = time.perf_counter() start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty()) self.stub.Ready(services_pb2.Empty())
end_time = time.perf_counter() end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s") self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions # send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config) policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes) policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server") self.logger.info("Sending policy instructions to policy server")
self.logger.debug( self.logger.debug(
@@ -170,7 +169,7 @@ class RobotClient:
self.stub.SendPolicyInstructions(policy_setup) self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set() self.shutdown_event.clear()
return True return True
@@ -180,7 +179,7 @@ class RobotClient:
def stop(self): def stop(self):
"""Stop the robot client""" """Stop the robot client"""
self._running_event.clear() self.shutdown_event.set()
self.robot.disconnect() self.robot.disconnect()
self.logger.debug("Robot disconnected") self.logger.debug("Robot disconnected")
@@ -208,7 +207,7 @@ class RobotClient:
try: try:
observation_iterator = send_bytes_in_chunks( observation_iterator = send_bytes_in_chunks(
observation_bytes, observation_bytes,
async_inference_pb2.Observation, services_pb2.Observation,
log_prefix="[CLIENT] Observation", log_prefix="[CLIENT] Observation",
silent=True, silent=True,
) )
@@ -283,7 +282,7 @@ class RobotClient:
while self.running: while self.running:
try: try:
# Use StreamActions to get a stream of actions from the server # Use StreamActions to get a stream of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty()) actions_chunk = self.stub.GetActions(services_pb2.Empty())
if len(actions_chunk.data) == 0: if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call continue # received `Empty` from server, wait for next call

View File

@@ -1,59 +0,0 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
rpc Stop(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}

View File

@@ -1,45 +0,0 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=190
_globals['_TRANSFERSTATE']._serialized_end=286
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTIONS']._serialized_start=127
_globals['_ACTIONS']._serialized_end=150
_globals['_POLICYSETUP']._serialized_start=152
_globals['_POLICYSETUP']._serialized_end=179
_globals['_EMPTY']._serialized_start=181
_globals['_EMPTY']._serialized_end=188
_globals['_ASYNCINFERENCE']._serialized_start=289
_globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)

View File

@@ -1,277 +0,0 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/async_inference.AsyncInference/GetActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Stop = channel.unary_unary(
'/async_inference.AsyncInference/Stop',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Stop(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Stop': grpc.unary_unary_rpc_method_handler(
servicer.Stop,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/GetActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Stop(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Stop',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -33,6 +33,17 @@ service LearnerService {
rpc Ready(Empty) returns (Empty); rpc Ready(Empty) returns (Empty);
} }
// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState { enum TransferState {
TRANSFER_UNKNOWN = 0; TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1; TRANSFER_BEGIN = 1;
@@ -56,4 +67,21 @@ message InteractionMessage {
bytes data = 2; bytes data = 2;
} }
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {} message Empty {}

View File

@@ -1,7 +1,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE # NO CHECKED-IN PROTOBUF GENCODE
# source: lerobot/transport/services.proto # source: lerobot/transport/services.proto
# Protobuf Python Version: 5.29.0 # Protobuf Python Version: 6.31.0
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
@@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC, _runtime_version.Domain.PUBLIC,
5, 6,
29, 31,
0, 0,
'', '',
'lerobot/transport/services.proto' 'lerobot/transport/services.proto'
@@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
_globals = globals() _globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS: if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=298 _globals['_TRANSFERSTATE']._serialized_start=431
_globals['_TRANSFERSTATE']._serialized_end=394 _globals['_TRANSFERSTATE']._serialized_end=527
_globals['_TRANSITION']._serialized_start=47 _globals['_TRANSITION']._serialized_start=47
_globals['_TRANSITION']._serialized_end=123 _globals['_TRANSITION']._serialized_end=123
_globals['_PARAMETERS']._serialized_start=125 _globals['_PARAMETERS']._serialized_start=125
_globals['_PARAMETERS']._serialized_end=201 _globals['_PARAMETERS']._serialized_end=201
_globals['_INTERACTIONMESSAGE']._serialized_start=203 _globals['_INTERACTIONMESSAGE']._serialized_start=203
_globals['_INTERACTIONMESSAGE']._serialized_end=287 _globals['_INTERACTIONMESSAGE']._serialized_end=287
_globals['_EMPTY']._serialized_start=289 _globals['_OBSERVATION']._serialized_start=289
_globals['_EMPTY']._serialized_end=296 _globals['_OBSERVATION']._serialized_end=366
_globals['_LEARNERSERVICE']._serialized_start=397 _globals['_ACTIONS']._serialized_start=368
_globals['_LEARNERSERVICE']._serialized_end=654 _globals['_ACTIONS']._serialized_end=391
_globals['_POLICYSETUP']._serialized_start=393
_globals['_POLICYSETUP']._serialized_end=420
_globals['_EMPTY']._serialized_start=422
_globals['_EMPTY']._serialized_end=429
_globals['_LEARNERSERVICE']._serialized_start=530
_globals['_LEARNERSERVICE']._serialized_end=787
_globals['_ASYNCINFERENCE']._serialized_start=790
_globals['_ASYNCINFERENCE']._serialized_end=1035
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@@ -5,7 +5,7 @@ import warnings
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2 from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
GRPC_GENERATED_VERSION = '1.71.0' GRPC_GENERATED_VERSION = '1.73.1'
GRPC_VERSION = grpc.__version__ GRPC_VERSION = grpc.__version__
_version_not_supported = False _version_not_supported = False
@@ -231,3 +231,212 @@ class LearnerService:
timeout, timeout,
metadata, metadata,
_registered_method=True) _registered_method=True)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/transport.AsyncInference/SendObservations',
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/transport.AsyncInference/GetActions',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/transport.AsyncInference/SendPolicyInstructions',
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/transport.AsyncInference/Ready',
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'transport.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/transport.AsyncInference/SendObservations',
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/GetActions',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/SendPolicyInstructions',
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/transport.AsyncInference/Ready',
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@@ -19,7 +19,8 @@ import io
import json import json
import logging import logging
import pickle # nosec B403: Safe usage for internal serialization only import pickle # nosec B403: Safe usage for internal serialization only
from multiprocessing import Event, Queue from multiprocessing import Event
from queue import Queue
from typing import Any from typing import Any
import torch import torch
@@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
bytes_buffer = io.BytesIO() bytes_buffer = io.BytesIO()
step = 0 step = 0
@@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
bytes_buffer.write(item.data) bytes_buffer.write(item.data)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
queue.put(bytes_buffer.getvalue()) if queue is not None:
queue.put(bytes_buffer.getvalue())
else:
return bytes_buffer.getvalue()
bytes_buffer.seek(0) bytes_buffer.seek(0)
bytes_buffer.truncate(0) bytes_buffer.truncate(0)

View File

@@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch):
from lerobot.scripts.server.policy_server import PolicyServer from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import ( from lerobot.transport import (
async_inference_pb2, # type: ignore services_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore services_pb2_grpc, # type: ignore
) )
from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_robot import MockRobotConfig
@@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch):
# Bypass potentially heavy model loading inside SendPolicyInstructions # Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802 def _fake_send_policy_instructions(self, request, context): # noqa: N802
return async_inference_pb2.Empty() return services_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True) monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer # Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server")) server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the fixture's config # Use the host/port specified in the fixture's config
server_address = f"{policy_server.config.host}:{policy_server.config.port}" server_address = f"{policy_server.config.host}:{policy_server.config.port}"