[Async Inference] Merge Protos & refactoring (#1480)
* Merge together proto files and refactor Async inference

* Fixup for Async inference

* Drop not required changes

* Fix tests

* Drop old async files

* Drop chunk_size param

* Fix versions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Fix wrong fix

  Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>

* Fixup

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca>
Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
@@ -95,7 +95,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
-grpcio-dep = ["grpcio==1.71.0"]
+grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]

# Motors
feetech = ["feetech-servo-sdk>=1.0.0"]
@@ -119,14 +119,14 @@ intelrealsense = [
# Policies
pi0 = ["lerobot[transformers-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
-hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
+hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]

# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]

# Development
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
-dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
+dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
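Note: with the extras above, both the gRPC pins and the generator travel through `grpcio-dep`, so client, server, and the regenerated stubs stay on matching versions. A minimal install sketch (the editable-checkout path is an assumption, not part of this diff):

    pip install -e ".[async]"   # pulls grpcio 1.73.1, protobuf 6.31.0, matplotlib
    pip install -e ".[dev]"     # adds grpcio-tools 1.73.1 for regenerating the stubs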
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
-from threading import Event
-from typing import Any

import torch
@@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig  # noqa: F401
from lerobot.robots.robot import Robot
-from lerobot.transport import async_inference_pb2
-from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging

Action = torch.Tensor
@@ -303,84 +298,3 @@ def observations_similar(
    )

    return _compare_observation_states(obs1_state, obs2_state, atol=atol)
-
-
-def send_bytes_in_chunks(
-    buffer: bytes,
-    message_class: Any,
-    log_prefix: str = "",
-    silent: bool = True,
-    chunk_size: int = 3 * 1024 * 1024,
-):
-    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
-    # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
-    # chunk size as I am using it to send image observations.
-    buffer = io.BytesIO(buffer)
-    size_in_bytes = bytes_buffer_size(buffer)
-
-    sent_bytes = 0
-
-    logging_method = logging.info if not silent else logging.debug
-
-    logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
-
-    while sent_bytes < size_in_bytes:
-        transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
-
-        if sent_bytes + chunk_size >= size_in_bytes:
-            transfer_state = async_inference_pb2.TransferState.TRANSFER_END
-        elif sent_bytes == 0:
-            transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
-
-        size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
-        chunk = buffer.read(size_to_read)
-
-        yield message_class(transfer_state=transfer_state, data=chunk)
-        sent_bytes += size_to_read
-        logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
-
-    logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
-
-
-def receive_bytes_in_chunks(
-    iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
-):  # type: ignore
-    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
-    # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
-    # is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
-    bytes_buffer = io.BytesIO()
-    step = 0
-
-    logger.info(f"{log_prefix} Starting receiver")
-    for item in iterator:
-        logger.debug(f"{log_prefix} Received item")
-        if not continue_receiving.is_set():
-            logger.info(f"{log_prefix} Shutting down receiver")
-            return
-
-        if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
-            bytes_buffer.seek(0)
-            bytes_buffer.truncate(0)
-            bytes_buffer.write(item.data)
-            logger.debug(f"{log_prefix} Received data at step 0")
-
-        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
-            bytes_buffer.write(item.data)
-            step += 1
-            logger.debug(f"{log_prefix} Received data at step {step}")
-
-        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
-            bytes_buffer.write(item.data)
-            logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
-
-            complete_bytes = bytes_buffer.getvalue()
-
-            bytes_buffer.seek(0)
-            bytes_buffer.truncate(0)
-
-            logger.debug(f"{log_prefix} Queue updated")
-            return complete_bytes
-
-        else:
-            logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
-            raise ValueError(f"Received unknown transfer state {item.transfer_state}")
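The two helpers deleted above existed only because the client and the server used message classes from a separate async_inference proto; with the protos merged into services.proto, both sides can call the shared helpers in lerobot.transport.utils. A minimal round-trip sketch of the unified API, assuming the signatures shown later in this diff (a threading.Event is used here since only is_set() is consulted):

    from threading import Event

    from lerobot.transport import services_pb2
    from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks

    payload = b"x" * (8 * 1024 * 1024)  # larger than a single gRPC message
    # chunk the payload into a stream of Observation messages
    messages = send_bytes_in_chunks(payload, services_pb2.Observation, log_prefix="[DEMO]")
    # queue=None (added in this PR) makes the receiver return the reassembled bytes
    assert receive_bytes_in_chunks(messages, None, Event(), "[DEMO]") == payload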
@@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import (
    get_logger,
    observations_similar,
    raw_observation_to_observation,
-    receive_bytes_in_chunks,
)
from lerobot.transport import (
-    async_inference_pb2,  # type: ignore
-    async_inference_pb2_grpc,  # type: ignore
+    services_pb2,  # type: ignore
+    services_pb2_grpc,  # type: ignore
)
+from lerobot.transport.utils import receive_bytes_in_chunks


-class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
+class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
    prefix = "policy_server"
    logger = get_logger(prefix)

    def __init__(self, config: PolicyServerConfig):
        self.config = config
-        self._running_event = threading.Event()
+        self.shutdown_event = threading.Event()

        # FPS measurement
        self.fps_tracker = FPSTracker(target_fps=config.fps)
@@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):

    @property
    def running(self):
-        return self._running_event.is_set()
+        return not self.shutdown_event.is_set()

    @property
    def policy_image_features(self):
@@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    def _reset_server(self) -> None:
        """Flushes server state when new client connects."""
        # only running inference on the latest observation received by the server
-        self._running_event.clear()
+        self.shutdown_event.set()
        self.observation_queue = Queue(maxsize=1)

        with self._predicted_timesteps_lock:
@@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
        client_id = context.peer()
        self.logger.info(f"Client {client_id} connected and ready")
        self._reset_server()
-        self._running_event.set()
+        self.shutdown_event.clear()

-        return async_inference_pb2.Empty()
+        return services_pb2.Empty()

    def SendPolicyInstructions(self, request, context):  # noqa: N802
        """Receive policy instructions from the robot client"""

        if not self.running:
            self.logger.warning("Server is not running. Ignoring policy instructions.")
-            return async_inference_pb2.Empty()
+            return services_pb2.Empty()

        client_id = context.peer()
@@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):

        self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

-        return async_inference_pb2.Empty()
+        return services_pb2.Empty()

    def SendObservations(self, request_iterator, context):  # noqa: N802
        """Receive observations from the robot client"""
@@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
        receive_time = time.time()  # comparing timestamps so need time.time()
        start_deserialize = time.perf_counter()
        received_bytes = receive_bytes_in_chunks(
-            request_iterator, self._running_event, self.logger
+            request_iterator, None, self.shutdown_event, self.logger
        )  # blocking call while looping over request_iterator
        timed_observation = pickle.loads(received_bytes)  # nosec
        deserialize_time = time.perf_counter() - start_deserialize
@@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
        ):
            self.logger.info(f"Observation #{obs_timestep} has been filtered out")

-        return async_inference_pb2.Empty()
+        return services_pb2.Empty()

    def GetActions(self, request, context):  # noqa: N802
        """Returns actions to the robot client. Actions are sent as a single
@@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
            serialize_time = time.perf_counter() - start_time

            # Create and return the action chunk
-            actions = async_inference_pb2.Actions(data=actions_bytes)
+            actions = services_pb2.Actions(data=actions_bytes)

            self.logger.info(
                f"Action chunk #{obs.get_timestep()} generated | "
@@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
            return actions

        except Empty:  # no observation added to queue in obs_queue_timeout
-            return async_inference_pb2.Empty()
+            return services_pb2.Empty()

        except Exception as e:
            self.logger.error(f"Error in StreamActions: {e}")

-            return async_inference_pb2.Empty()
+            return services_pb2.Empty()

    def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
        """Check if the observation is valid to be processed by the policy"""
@@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig):

    # Setup and start gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
-    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
+    services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"{cfg.host}:{cfg.port}")

    policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
    TimedObservation,
    get_logger,
    map_robot_keys_to_lerobot_features,
-    send_bytes_in_chunks,
    validate_robot_cameras_for_policy,
    visualize_action_queue_size,
)
from lerobot.transport import (
-    async_inference_pb2,  # type: ignore
-    async_inference_pb2_grpc,  # type: ignore
+    services_pb2,  # type: ignore
+    services_pb2_grpc,  # type: ignore
)
-from lerobot.transport.utils import grpc_channel_options
+from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks


class RobotClient:
@@ -118,10 +117,10 @@ class RobotClient:
        self.channel = grpc.insecure_channel(
            self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
        )
-        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
+        self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
        self.logger.info(f"Initializing client to connect to server at {self.server_address}")

-        self._running_event = threading.Event()
+        self.shutdown_event = threading.Event()

        # Initialize client side variables
        self.latest_action_lock = threading.Lock()
@@ -146,20 +145,20 @@ class RobotClient:

    @property
    def running(self):
-        return self._running_event.is_set()
+        return not self.shutdown_event.is_set()

    def start(self):
        """Start the robot client and connect to the policy server"""
        try:
            # client-server handshake
            start_time = time.perf_counter()
-            self.stub.Ready(async_inference_pb2.Empty())
+            self.stub.Ready(services_pb2.Empty())
            end_time = time.perf_counter()
            self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
-            policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
+            policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)

            self.logger.info("Sending policy instructions to policy server")
            self.logger.debug(
@@ -170,7 +169,7 @@ class RobotClient:

            self.stub.SendPolicyInstructions(policy_setup)

-            self._running_event.set()
+            self.shutdown_event.clear()

            return True
@@ -180,7 +179,7 @@ class RobotClient:

    def stop(self):
        """Stop the robot client"""
-        self._running_event.clear()
+        self.shutdown_event.set()

        self.robot.disconnect()
        self.logger.debug("Robot disconnected")
@@ -208,7 +207,7 @@ class RobotClient:
        try:
            observation_iterator = send_bytes_in_chunks(
                observation_bytes,
-                async_inference_pb2.Observation,
+                services_pb2.Observation,
                log_prefix="[CLIENT] Observation",
                silent=True,
            )
@@ -283,7 +282,7 @@ class RobotClient:
        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
-                actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
+                actions_chunk = self.stub.GetActions(services_pb2.Empty())
                if len(actions_chunk.data) == 0:
                    continue  # received `Empty` from server, wait for next call
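Taken together, the client calls in this file reduce to a short gRPC session against the merged service. A sketch of the wire-level flow (the address and the pickled setup payload are illustrative placeholders, not values from this diff):

    import pickle

    import grpc

    from lerobot.transport import services_pb2, services_pb2_grpc

    channel = grpc.insecure_channel("localhost:8080")  # placeholder address
    stub = services_pb2_grpc.AsyncInferenceStub(channel)

    stub.Ready(services_pb2.Empty())  # handshake, mirrors RobotClient.start()
    setup = services_pb2.PolicySetup(data=pickle.dumps({"policy_type": "act"}))  # placeholder config
    stub.SendPolicyInstructions(setup)
    actions = stub.GetActions(services_pb2.Empty())  # empty `data` means no chunk is ready yet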
@@ -1,59 +0,0 @@
-// fmt: off
-// flake8: noqa
-// !/usr/bin/env python
-
-// Copyright 2024 The HuggingFace Inc. team.
-// All rights reserved.
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-
-//     http://www.apache.org/licenses/LICENSE-2.0
-
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-syntax = "proto3";
-
-package async_inference;
-
-// AsyncInference: from Robot perspective
-// Robot send observations to & executes action received from a remote Policy server
-service AsyncInference {
-    // Robot -> Policy to share observations with a remote inference server
-    // Policy -> Robot to share actions predicted for given observations
-    rpc SendObservations(stream Observation) returns (Empty);
-    rpc GetActions(Empty) returns (Actions);
-    rpc SendPolicyInstructions(PolicySetup) returns (Empty);
-    rpc Ready(Empty) returns (Empty);
-    rpc Stop(Empty) returns (Empty);
-}
-
-enum TransferState {
-    TRANSFER_UNKNOWN = 0;
-    TRANSFER_BEGIN = 1;
-    TRANSFER_MIDDLE = 2;
-    TRANSFER_END = 3;
-}
-
-// Messages
-message Observation {
-    // sent by Robot, to remote Policy
-    TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
-    bytes data = 2;
-}
-
-message Actions {
-    // sent by remote Policy, to Robot
-    bytes data = 1;
-}
-
-message PolicySetup {
-    // sent by Robot to remote server, to init Policy
-    bytes data = 1;
-}
-
-message Empty {}
@@ -1,45 +0,0 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# NO CHECKED-IN PROTOBUF GENCODE
-# source: async_inference.proto
-# Protobuf Python Version: 5.29.0
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import runtime_version as _runtime_version
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-_runtime_version.ValidateProtobufRuntimeVersion(
-    _runtime_version.Domain.PUBLIC,
-    5,
-    29,
-    0,
-    '',
-    'async_inference.proto'
-)
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
-if not _descriptor._USE_C_DESCRIPTORS:
-    DESCRIPTOR._loaded_options = None
-    _globals['_TRANSFERSTATE']._serialized_start=190
-    _globals['_TRANSFERSTATE']._serialized_end=286
-    _globals['_OBSERVATION']._serialized_start=42
-    _globals['_OBSERVATION']._serialized_end=125
-    _globals['_ACTIONS']._serialized_start=127
-    _globals['_ACTIONS']._serialized_end=150
-    _globals['_POLICYSETUP']._serialized_start=152
-    _globals['_POLICYSETUP']._serialized_end=179
-    _globals['_EMPTY']._serialized_start=181
-    _globals['_EMPTY']._serialized_end=188
-    _globals['_ASYNCINFERENCE']._serialized_start=289
-    _globals['_ASYNCINFERENCE']._serialized_end=638
-# @@protoc_insertion_point(module_scope)
@@ -1,277 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-import warnings
-
-from lerobot.transport import async_inference_pb2 as async__inference__pb2
-
-GRPC_GENERATED_VERSION = '1.71.0'
-GRPC_VERSION = grpc.__version__
-_version_not_supported = False
-
-try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
-except ImportError:
-    _version_not_supported = True
-
-if _version_not_supported:
-    raise RuntimeError(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in async_inference_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-    )
-
-
-class AsyncInferenceStub:
-    """AsyncInference: from Robot perspective
-    Robot send observations to & executes action received from a remote Policy server
-    """
-
-    def __init__(self, channel):
-        """Constructor.
-
-        Args:
-            channel: A grpc.Channel.
-        """
-        self.SendObservations = channel.stream_unary(
-            '/async_inference.AsyncInference/SendObservations',
-            request_serializer=async__inference__pb2.Observation.SerializeToString,
-            response_deserializer=async__inference__pb2.Empty.FromString,
-            _registered_method=True)
-        self.GetActions = channel.unary_unary(
-            '/async_inference.AsyncInference/GetActions',
-            request_serializer=async__inference__pb2.Empty.SerializeToString,
-            response_deserializer=async__inference__pb2.Actions.FromString,
-            _registered_method=True)
-        self.SendPolicyInstructions = channel.unary_unary(
-            '/async_inference.AsyncInference/SendPolicyInstructions',
-            request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
-            response_deserializer=async__inference__pb2.Empty.FromString,
-            _registered_method=True)
-        self.Ready = channel.unary_unary(
-            '/async_inference.AsyncInference/Ready',
-            request_serializer=async__inference__pb2.Empty.SerializeToString,
-            response_deserializer=async__inference__pb2.Empty.FromString,
-            _registered_method=True)
-        self.Stop = channel.unary_unary(
-            '/async_inference.AsyncInference/Stop',
-            request_serializer=async__inference__pb2.Empty.SerializeToString,
-            response_deserializer=async__inference__pb2.Empty.FromString,
-            _registered_method=True)
-
-
-class AsyncInferenceServicer:
-    """AsyncInference: from Robot perspective
-    Robot send observations to & executes action received from a remote Policy server
-    """
-
-    def SendObservations(self, request_iterator, context):
-        """Robot -> Policy to share observations with a remote inference server
-        Policy -> Robot to share actions predicted for given observations
-        """
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def GetActions(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendPolicyInstructions(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def Ready(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def Stop(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-
-def add_AsyncInferenceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-        'SendObservations': grpc.stream_unary_rpc_method_handler(
-            servicer.SendObservations,
-            request_deserializer=async__inference__pb2.Observation.FromString,
-            response_serializer=async__inference__pb2.Empty.SerializeToString,
-        ),
-        'GetActions': grpc.unary_unary_rpc_method_handler(
-            servicer.GetActions,
-            request_deserializer=async__inference__pb2.Empty.FromString,
-            response_serializer=async__inference__pb2.Actions.SerializeToString,
-        ),
-        'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
-            servicer.SendPolicyInstructions,
-            request_deserializer=async__inference__pb2.PolicySetup.FromString,
-            response_serializer=async__inference__pb2.Empty.SerializeToString,
-        ),
-        'Ready': grpc.unary_unary_rpc_method_handler(
-            servicer.Ready,
-            request_deserializer=async__inference__pb2.Empty.FromString,
-            response_serializer=async__inference__pb2.Empty.SerializeToString,
-        ),
-        'Stop': grpc.unary_unary_rpc_method_handler(
-            servicer.Stop,
-            request_deserializer=async__inference__pb2.Empty.FromString,
-            response_serializer=async__inference__pb2.Empty.SerializeToString,
-        ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-        'async_inference.AsyncInference', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
-
-
-# This class is part of an EXPERIMENTAL API.
-class AsyncInference:
-    """AsyncInference: from Robot perspective
-    Robot send observations to & executes action received from a remote Policy server
-    """
-
-    @staticmethod
-    def SendObservations(request_iterator,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.stream_unary(
-            request_iterator,
-            target,
-            '/async_inference.AsyncInference/SendObservations',
-            async__inference__pb2.Observation.SerializeToString,
-            async__inference__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def GetActions(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/async_inference.AsyncInference/GetActions',
-            async__inference__pb2.Empty.SerializeToString,
-            async__inference__pb2.Actions.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def SendPolicyInstructions(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/async_inference.AsyncInference/SendPolicyInstructions',
-            async__inference__pb2.PolicySetup.SerializeToString,
-            async__inference__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def Ready(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/async_inference.AsyncInference/Ready',
-            async__inference__pb2.Empty.SerializeToString,
-            async__inference__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def Stop(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/async_inference.AsyncInference/Stop',
-            async__inference__pb2.Empty.SerializeToString,
-            async__inference__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
@@ -33,6 +33,17 @@ service LearnerService {
    rpc Ready(Empty) returns (Empty);
}

+// AsyncInference: from Robot perspective
+// Robot send observations to & executes action received from a remote Policy server
+service AsyncInference {
+    // Robot -> Policy to share observations with a remote inference server
+    // Policy -> Robot to share actions predicted for given observations
+    rpc SendObservations(stream Observation) returns (Empty);
+    rpc GetActions(Empty) returns (Actions);
+    rpc SendPolicyInstructions(PolicySetup) returns (Empty);
+    rpc Ready(Empty) returns (Empty);
+}
+
enum TransferState {
    TRANSFER_UNKNOWN = 0;
    TRANSFER_BEGIN = 1;
@@ -56,4 +67,21 @@ message InteractionMessage {
    bytes data = 2;
}

+// Messages
+message Observation {
+    // sent by Robot, to remote Policy
+    TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
+    bytes data = 2;
+}
+
+message Actions {
+    // sent by remote Policy, to Robot
+    bytes data = 1;
+}
+
+message PolicySetup {
+    // sent by Robot to remote server, to init Policy
+    bytes data = 1;
+}
+
message Empty {}
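Since both LearnerService and AsyncInference now live in services.proto, the checked-in _pb2 modules are regenerated from a single file. A plausible invocation using the grpcio-tools pin from pyproject.toml (run from the package source root; the exact include path is an assumption, not part of this diff):

    python -m grpc_tools.protoc \
        -I. \
        --python_out=. \
        --grpc_python_out=. \
        lerobot/transport/services.proto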
@@ -1,7 +1,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: lerobot/transport/services.proto
-# Protobuf Python Version: 5.29.0
+# Protobuf Python Version: 6.31.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
-    5,
-    29,
+    6,
+    31,
    0,
    '',
    'lerobot/transport/services.proto'
@@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default()


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
    DESCRIPTOR._loaded_options = None
-    _globals['_TRANSFERSTATE']._serialized_start=298
-    _globals['_TRANSFERSTATE']._serialized_end=394
+    _globals['_TRANSFERSTATE']._serialized_start=431
+    _globals['_TRANSFERSTATE']._serialized_end=527
    _globals['_TRANSITION']._serialized_start=47
    _globals['_TRANSITION']._serialized_end=123
    _globals['_PARAMETERS']._serialized_start=125
    _globals['_PARAMETERS']._serialized_end=201
    _globals['_INTERACTIONMESSAGE']._serialized_start=203
    _globals['_INTERACTIONMESSAGE']._serialized_end=287
-    _globals['_EMPTY']._serialized_start=289
-    _globals['_EMPTY']._serialized_end=296
-    _globals['_LEARNERSERVICE']._serialized_start=397
-    _globals['_LEARNERSERVICE']._serialized_end=654
+    _globals['_OBSERVATION']._serialized_start=289
+    _globals['_OBSERVATION']._serialized_end=366
+    _globals['_ACTIONS']._serialized_start=368
+    _globals['_ACTIONS']._serialized_end=391
+    _globals['_POLICYSETUP']._serialized_start=393
+    _globals['_POLICYSETUP']._serialized_end=420
+    _globals['_EMPTY']._serialized_start=422
+    _globals['_EMPTY']._serialized_end=429
+    _globals['_LEARNERSERVICE']._serialized_start=530
+    _globals['_LEARNERSERVICE']._serialized_end=787
+    _globals['_ASYNCINFERENCE']._serialized_start=790
+    _globals['_ASYNCINFERENCE']._serialized_end=1035
# @@protoc_insertion_point(module_scope)
@@ -5,7 +5,7 @@ import warnings

from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2

-GRPC_GENERATED_VERSION = '1.71.0'
+GRPC_GENERATED_VERSION = '1.73.1'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
@@ -231,3 +231,212 @@ class LearnerService:
            timeout,
            metadata,
            _registered_method=True)
+
+
+class AsyncInferenceStub:
+    """AsyncInference: from Robot perspective
+    Robot send observations to & executes action received from a remote Policy server
+    """
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.SendObservations = channel.stream_unary(
+            '/transport.AsyncInference/SendObservations',
+            request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
+            response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            _registered_method=True)
+        self.GetActions = channel.unary_unary(
+            '/transport.AsyncInference/GetActions',
+            request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+            response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
+            _registered_method=True)
+        self.SendPolicyInstructions = channel.unary_unary(
+            '/transport.AsyncInference/SendPolicyInstructions',
+            request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
+            response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            _registered_method=True)
+        self.Ready = channel.unary_unary(
+            '/transport.AsyncInference/Ready',
+            request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+            response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            _registered_method=True)
+
+
+class AsyncInferenceServicer:
+    """AsyncInference: from Robot perspective
+    Robot send observations to & executes action received from a remote Policy server
+    """
+
+    def SendObservations(self, request_iterator, context):
+        """Robot -> Policy to share observations with a remote inference server
+        Policy -> Robot to share actions predicted for given observations
+        """
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def GetActions(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def SendPolicyInstructions(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+    def Ready(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_AsyncInferenceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+        'SendObservations': grpc.stream_unary_rpc_method_handler(
+            servicer.SendObservations,
+            request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
+            response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+        ),
+        'GetActions': grpc.unary_unary_rpc_method_handler(
+            servicer.GetActions,
+            request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
+        ),
+        'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
+            servicer.SendPolicyInstructions,
+            request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
+            response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+        ),
+        'Ready': grpc.unary_unary_rpc_method_handler(
+            servicer.Ready,
+            request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+        ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+        'transport.AsyncInference', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+    server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
+
+
+# This class is part of an EXPERIMENTAL API.
+class AsyncInference:
+    """AsyncInference: from Robot perspective
+    Robot send observations to & executes action received from a remote Policy server
+    """
+
+    @staticmethod
+    def SendObservations(request_iterator,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.stream_unary(
+            request_iterator,
+            target,
+            '/transport.AsyncInference/SendObservations',
+            lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
+            lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def GetActions(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/transport.AsyncInference/GetActions',
+            lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+            lerobot_dot_transport_dot_services__pb2.Actions.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def SendPolicyInstructions(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/transport.AsyncInference/SendPolicyInstructions',
+            lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
+            lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
+
+    @staticmethod
+    def Ready(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            '/transport.AsyncInference/Ready',
+            lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
+            lerobot_dot_transport_dot_services__pb2.Empty.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
@@ -19,7 +19,8 @@ import io
import json
import logging
import pickle  # nosec B403: Safe usage for internal serialization only
-from multiprocessing import Event, Queue
+from multiprocessing import Event
+from queue import Queue
from typing import Any

import torch
@@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
    logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")


-def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):  # type: ignore
+def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
    bytes_buffer = io.BytesIO()
    step = 0
@@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
            bytes_buffer.write(item.data)
            logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")

-            queue.put(bytes_buffer.getvalue())
+            if queue is not None:
+                queue.put(bytes_buffer.getvalue())
+            else:
+                return bytes_buffer.getvalue()

            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)
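The queue=None branch above is what lets the policy server reuse this helper: HIL-SERL consumers pass a Queue and keep receiving across messages, while the async-inference server awaits a single payload and takes the return value. A small self-contained sketch of the return path, assuming the behaviour shown in this hunk (enum and field names come from the merged proto; the two-chunk split is arbitrary):

    from threading import Event

    from lerobot.transport import services_pb2
    from lerobot.transport.utils import receive_bytes_in_chunks

    msgs = [
        services_pb2.Observation(transfer_state=services_pb2.TRANSFER_BEGIN, data=b"he"),
        services_pb2.Observation(transfer_state=services_pb2.TRANSFER_END, data=b"llo"),
    ]
    # unset Event means "keep receiving"; queue=None returns the reassembled bytes
    assert receive_bytes_in_chunks(iter(msgs), None, Event(), "[DEMO]") == b"hello"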
@@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch):
    from lerobot.scripts.server.policy_server import PolicyServer
    from lerobot.scripts.server.robot_client import RobotClient
    from lerobot.transport import (
-        async_inference_pb2,  # type: ignore
-        async_inference_pb2_grpc,  # type: ignore
+        services_pb2,  # type: ignore
+        services_pb2_grpc,  # type: ignore
    )
    from tests.mocks.mock_robot import MockRobotConfig
@@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch):

    # Bypass potentially heavy model loading inside SendPolicyInstructions
    def _fake_send_policy_instructions(self, request, context):  # noqa: N802
-        return async_inference_pb2.Empty()
+        return services_pb2.Empty()

    monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)

    # Build gRPC server running a PolicyServer
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
-    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
+    services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)

    # Use the host/port specified in the fixture's config
    server_address = f"{policy_server.config.host}:{policy_server.config.port}"