[Async Inference] Merge Protos & refactoring (#1480)
* Merge together proto files and refactor Async inference * Fixup for Async inference * Drop not reuqired changes * Fix tests * Drop old async files * Drop chunk_size param * Fix versions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix wrong fix Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca> * Fixup --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Ben Zhang <ben.zhang@uwaterloo.ca> Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
@@ -95,7 +95,7 @@ dependencies = [
|
|||||||
pygame-dep = ["pygame>=2.5.1"]
|
pygame-dep = ["pygame>=2.5.1"]
|
||||||
placo-dep = ["placo>=0.9.6"]
|
placo-dep = ["placo>=0.9.6"]
|
||||||
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
|
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
|
||||||
grpcio-dep = ["grpcio==1.71.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||||
|
|
||||||
# Motors
|
# Motors
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0"]
|
feetech = ["feetech-servo-sdk>=1.0.0"]
|
||||||
@@ -119,14 +119,14 @@ intelrealsense = [
|
|||||||
# Policies
|
# Policies
|
||||||
pi0 = ["lerobot[transformers-dep]"]
|
pi0 = ["lerobot[transformers-dep]"]
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
|
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
|
||||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"]
|
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||||
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||||
|
|
||||||
|
|||||||
@@ -12,15 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Event
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
|||||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||||
from lerobot.robots.robot import Robot
|
from lerobot.robots.robot import Robot
|
||||||
from lerobot.transport import async_inference_pb2
|
|
||||||
from lerobot.transport.utils import bytes_buffer_size
|
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
Action = torch.Tensor
|
Action = torch.Tensor
|
||||||
@@ -303,84 +298,3 @@ def observations_similar(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
def send_bytes_in_chunks(
|
|
||||||
buffer: bytes,
|
|
||||||
message_class: Any,
|
|
||||||
log_prefix: str = "",
|
|
||||||
silent: bool = True,
|
|
||||||
chunk_size: int = 3 * 1024 * 1024,
|
|
||||||
):
|
|
||||||
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
|
|
||||||
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
|
|
||||||
# chunk size as I am using it to send image observations.
|
|
||||||
buffer = io.BytesIO(buffer)
|
|
||||||
size_in_bytes = bytes_buffer_size(buffer)
|
|
||||||
|
|
||||||
sent_bytes = 0
|
|
||||||
|
|
||||||
logging_method = logging.info if not silent else logging.debug
|
|
||||||
|
|
||||||
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
|
||||||
|
|
||||||
while sent_bytes < size_in_bytes:
|
|
||||||
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
|
|
||||||
|
|
||||||
if sent_bytes + chunk_size >= size_in_bytes:
|
|
||||||
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
|
|
||||||
elif sent_bytes == 0:
|
|
||||||
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
|
|
||||||
|
|
||||||
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
|
|
||||||
chunk = buffer.read(size_to_read)
|
|
||||||
|
|
||||||
yield message_class(transfer_state=transfer_state, data=chunk)
|
|
||||||
sent_bytes += size_to_read
|
|
||||||
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
|
||||||
|
|
||||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
|
||||||
|
|
||||||
|
|
||||||
def receive_bytes_in_chunks(
|
|
||||||
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
|
|
||||||
): # type: ignore
|
|
||||||
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
|
|
||||||
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
|
|
||||||
# is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
|
|
||||||
bytes_buffer = io.BytesIO()
|
|
||||||
step = 0
|
|
||||||
|
|
||||||
logger.info(f"{log_prefix} Starting receiver")
|
|
||||||
for item in iterator:
|
|
||||||
logger.debug(f"{log_prefix} Received item")
|
|
||||||
if not continue_receiving.is_set():
|
|
||||||
logger.info(f"{log_prefix} Shutting down receiver")
|
|
||||||
return
|
|
||||||
|
|
||||||
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
|
|
||||||
bytes_buffer.seek(0)
|
|
||||||
bytes_buffer.truncate(0)
|
|
||||||
bytes_buffer.write(item.data)
|
|
||||||
logger.debug(f"{log_prefix} Received data at step 0")
|
|
||||||
|
|
||||||
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
|
|
||||||
bytes_buffer.write(item.data)
|
|
||||||
step += 1
|
|
||||||
logger.debug(f"{log_prefix} Received data at step {step}")
|
|
||||||
|
|
||||||
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
|
|
||||||
bytes_buffer.write(item.data)
|
|
||||||
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
|
||||||
|
|
||||||
complete_bytes = bytes_buffer.getvalue()
|
|
||||||
|
|
||||||
bytes_buffer.seek(0)
|
|
||||||
bytes_buffer.truncate(0)
|
|
||||||
|
|
||||||
logger.debug(f"{log_prefix} Queue updated")
|
|
||||||
return complete_bytes
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
|
|
||||||
raise ValueError(f"Received unknown transfer state {item.transfer_state}")
|
|
||||||
|
|||||||
@@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import (
|
|||||||
get_logger,
|
get_logger,
|
||||||
observations_similar,
|
observations_similar,
|
||||||
raw_observation_to_observation,
|
raw_observation_to_observation,
|
||||||
receive_bytes_in_chunks,
|
|
||||||
)
|
)
|
||||||
from lerobot.transport import (
|
from lerobot.transport import (
|
||||||
async_inference_pb2, # type: ignore
|
services_pb2, # type: ignore
|
||||||
async_inference_pb2_grpc, # type: ignore
|
services_pb2_grpc, # type: ignore
|
||||||
)
|
)
|
||||||
|
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||||
|
|
||||||
|
|
||||||
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||||
prefix = "policy_server"
|
prefix = "policy_server"
|
||||||
logger = get_logger(prefix)
|
logger = get_logger(prefix)
|
||||||
|
|
||||||
def __init__(self, config: PolicyServerConfig):
|
def __init__(self, config: PolicyServerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._running_event = threading.Event()
|
self.shutdown_event = threading.Event()
|
||||||
|
|
||||||
# FPS measurement
|
# FPS measurement
|
||||||
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
self.fps_tracker = FPSTracker(target_fps=config.fps)
|
||||||
@@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self):
|
def running(self):
|
||||||
return self._running_event.is_set()
|
return not self.shutdown_event.is_set()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def policy_image_features(self):
|
def policy_image_features(self):
|
||||||
@@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
def _reset_server(self) -> None:
|
def _reset_server(self) -> None:
|
||||||
"""Flushes server state when new client connects."""
|
"""Flushes server state when new client connects."""
|
||||||
# only running inference on the latest observation received by the server
|
# only running inference on the latest observation received by the server
|
||||||
self._running_event.clear()
|
self.shutdown_event.set()
|
||||||
self.observation_queue = Queue(maxsize=1)
|
self.observation_queue = Queue(maxsize=1)
|
||||||
|
|
||||||
with self._predicted_timesteps_lock:
|
with self._predicted_timesteps_lock:
|
||||||
@@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
client_id = context.peer()
|
client_id = context.peer()
|
||||||
self.logger.info(f"Client {client_id} connected and ready")
|
self.logger.info(f"Client {client_id} connected and ready")
|
||||||
self._reset_server()
|
self._reset_server()
|
||||||
self._running_event.set()
|
self.shutdown_event.clear()
|
||||||
|
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendPolicyInstructions(self, request, context): # noqa: N802
|
def SendPolicyInstructions(self, request, context): # noqa: N802
|
||||||
"""Receive policy instructions from the robot client"""
|
"""Receive policy instructions from the robot client"""
|
||||||
|
|
||||||
if not self.running:
|
if not self.running:
|
||||||
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
self.logger.warning("Server is not running. Ignoring policy instructions.")
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
client_id = context.peer()
|
client_id = context.peer()
|
||||||
|
|
||||||
@@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
|
|
||||||
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
|
||||||
|
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def SendObservations(self, request_iterator, context): # noqa: N802
|
def SendObservations(self, request_iterator, context): # noqa: N802
|
||||||
"""Receive observations from the robot client"""
|
"""Receive observations from the robot client"""
|
||||||
@@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
receive_time = time.time() # comparing timestamps so need time.time()
|
receive_time = time.time() # comparing timestamps so need time.time()
|
||||||
start_deserialize = time.perf_counter()
|
start_deserialize = time.perf_counter()
|
||||||
received_bytes = receive_bytes_in_chunks(
|
received_bytes = receive_bytes_in_chunks(
|
||||||
request_iterator, self._running_event, self.logger
|
request_iterator, None, self.shutdown_event, self.logger
|
||||||
) # blocking call while looping over request_iterator
|
) # blocking call while looping over request_iterator
|
||||||
timed_observation = pickle.loads(received_bytes) # nosec
|
timed_observation = pickle.loads(received_bytes) # nosec
|
||||||
deserialize_time = time.perf_counter() - start_deserialize
|
deserialize_time = time.perf_counter() - start_deserialize
|
||||||
@@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
):
|
):
|
||||||
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
|
||||||
|
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def GetActions(self, request, context): # noqa: N802
|
def GetActions(self, request, context): # noqa: N802
|
||||||
"""Returns actions to the robot client. Actions are sent as a single
|
"""Returns actions to the robot client. Actions are sent as a single
|
||||||
@@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
serialize_time = time.perf_counter() - start_time
|
serialize_time = time.perf_counter() - start_time
|
||||||
|
|
||||||
# Create and return the action chunk
|
# Create and return the action chunk
|
||||||
actions = async_inference_pb2.Actions(data=actions_bytes)
|
actions = services_pb2.Actions(data=actions_bytes)
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Action chunk #{obs.get_timestep()} generated | "
|
f"Action chunk #{obs.get_timestep()} generated | "
|
||||||
@@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
except Empty: # no observation added to queue in obs_queue_timeout
|
except Empty: # no observation added to queue in obs_queue_timeout
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error in StreamActions: {e}")
|
self.logger.error(f"Error in StreamActions: {e}")
|
||||||
|
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
|
||||||
"""Check if the observation is valid to be processed by the policy"""
|
"""Check if the observation is valid to be processed by the policy"""
|
||||||
@@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig):
|
|||||||
|
|
||||||
# Setup and start gRPC server
|
# Setup and start gRPC server
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
|
||||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||||
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
|
||||||
|
|
||||||
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
|
||||||
|
|||||||
@@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import (
|
|||||||
TimedObservation,
|
TimedObservation,
|
||||||
get_logger,
|
get_logger,
|
||||||
map_robot_keys_to_lerobot_features,
|
map_robot_keys_to_lerobot_features,
|
||||||
send_bytes_in_chunks,
|
|
||||||
validate_robot_cameras_for_policy,
|
validate_robot_cameras_for_policy,
|
||||||
visualize_action_queue_size,
|
visualize_action_queue_size,
|
||||||
)
|
)
|
||||||
from lerobot.transport import (
|
from lerobot.transport import (
|
||||||
async_inference_pb2, # type: ignore
|
services_pb2, # type: ignore
|
||||||
async_inference_pb2_grpc, # type: ignore
|
services_pb2_grpc, # type: ignore
|
||||||
)
|
)
|
||||||
from lerobot.transport.utils import grpc_channel_options
|
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||||
|
|
||||||
|
|
||||||
class RobotClient:
|
class RobotClient:
|
||||||
@@ -118,10 +117,10 @@ class RobotClient:
|
|||||||
self.channel = grpc.insecure_channel(
|
self.channel = grpc.insecure_channel(
|
||||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||||
)
|
)
|
||||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||||
|
|
||||||
self._running_event = threading.Event()
|
self.shutdown_event = threading.Event()
|
||||||
|
|
||||||
# Initialize client side variables
|
# Initialize client side variables
|
||||||
self.latest_action_lock = threading.Lock()
|
self.latest_action_lock = threading.Lock()
|
||||||
@@ -146,20 +145,20 @@ class RobotClient:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def running(self):
|
def running(self):
|
||||||
return self._running_event.is_set()
|
return not self.shutdown_event.is_set()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the robot client and connect to the policy server"""
|
"""Start the robot client and connect to the policy server"""
|
||||||
try:
|
try:
|
||||||
# client-server handshake
|
# client-server handshake
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
self.stub.Ready(async_inference_pb2.Empty())
|
self.stub.Ready(services_pb2.Empty())
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||||
|
|
||||||
# send policy instructions
|
# send policy instructions
|
||||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||||
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
|
policy_setup = services_pb2.PolicySetup(data=policy_config_bytes)
|
||||||
|
|
||||||
self.logger.info("Sending policy instructions to policy server")
|
self.logger.info("Sending policy instructions to policy server")
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
@@ -170,7 +169,7 @@ class RobotClient:
|
|||||||
|
|
||||||
self.stub.SendPolicyInstructions(policy_setup)
|
self.stub.SendPolicyInstructions(policy_setup)
|
||||||
|
|
||||||
self._running_event.set()
|
self.shutdown_event.clear()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -180,7 +179,7 @@ class RobotClient:
|
|||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the robot client"""
|
"""Stop the robot client"""
|
||||||
self._running_event.clear()
|
self.shutdown_event.set()
|
||||||
|
|
||||||
self.robot.disconnect()
|
self.robot.disconnect()
|
||||||
self.logger.debug("Robot disconnected")
|
self.logger.debug("Robot disconnected")
|
||||||
@@ -208,7 +207,7 @@ class RobotClient:
|
|||||||
try:
|
try:
|
||||||
observation_iterator = send_bytes_in_chunks(
|
observation_iterator = send_bytes_in_chunks(
|
||||||
observation_bytes,
|
observation_bytes,
|
||||||
async_inference_pb2.Observation,
|
services_pb2.Observation,
|
||||||
log_prefix="[CLIENT] Observation",
|
log_prefix="[CLIENT] Observation",
|
||||||
silent=True,
|
silent=True,
|
||||||
)
|
)
|
||||||
@@ -283,7 +282,7 @@ class RobotClient:
|
|||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
# Use StreamActions to get a stream of actions from the server
|
# Use StreamActions to get a stream of actions from the server
|
||||||
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
|
actions_chunk = self.stub.GetActions(services_pb2.Empty())
|
||||||
if len(actions_chunk.data) == 0:
|
if len(actions_chunk.data) == 0:
|
||||||
continue # received `Empty` from server, wait for next call
|
continue # received `Empty` from server, wait for next call
|
||||||
|
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
// fmt: off
|
|
||||||
// flake8: noqa
|
|
||||||
// !/usr/bin/env python
|
|
||||||
|
|
||||||
// Copyright 2024 The HuggingFace Inc. team.
|
|
||||||
// All rights reserved.
|
|
||||||
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
syntax = "proto3";
|
|
||||||
|
|
||||||
package async_inference;
|
|
||||||
|
|
||||||
// AsyncInference: from Robot perspective
|
|
||||||
// Robot send observations to & executes action received from a remote Policy server
|
|
||||||
service AsyncInference {
|
|
||||||
// Robot -> Policy to share observations with a remote inference server
|
|
||||||
// Policy -> Robot to share actions predicted for given observations
|
|
||||||
rpc SendObservations(stream Observation) returns (Empty);
|
|
||||||
rpc GetActions(Empty) returns (Actions);
|
|
||||||
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
|
||||||
rpc Ready(Empty) returns (Empty);
|
|
||||||
rpc Stop(Empty) returns (Empty);
|
|
||||||
}
|
|
||||||
|
|
||||||
enum TransferState {
|
|
||||||
TRANSFER_UNKNOWN = 0;
|
|
||||||
TRANSFER_BEGIN = 1;
|
|
||||||
TRANSFER_MIDDLE = 2;
|
|
||||||
TRANSFER_END = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Messages
|
|
||||||
message Observation {
|
|
||||||
// sent by Robot, to remote Policy
|
|
||||||
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
|
|
||||||
bytes data = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Actions {
|
|
||||||
// sent by remote Policy, to Robot
|
|
||||||
bytes data = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message PolicySetup {
|
|
||||||
// sent by Robot to remote server, to init Policy
|
|
||||||
bytes data = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Empty {}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
||||||
# NO CHECKED-IN PROTOBUF GENCODE
|
|
||||||
# source: async_inference.proto
|
|
||||||
# Protobuf Python Version: 5.29.0
|
|
||||||
"""Generated protocol buffer code."""
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
||||||
from google.protobuf import runtime_version as _runtime_version
|
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
||||||
_runtime_version.Domain.PUBLIC,
|
|
||||||
5,
|
|
||||||
29,
|
|
||||||
0,
|
|
||||||
'',
|
|
||||||
'async_inference.proto'
|
|
||||||
)
|
|
||||||
# @@protoc_insertion_point(imports)
|
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
|
|
||||||
|
|
||||||
_globals = globals()
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
|
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
|
||||||
DESCRIPTOR._loaded_options = None
|
|
||||||
_globals['_TRANSFERSTATE']._serialized_start=190
|
|
||||||
_globals['_TRANSFERSTATE']._serialized_end=286
|
|
||||||
_globals['_OBSERVATION']._serialized_start=42
|
|
||||||
_globals['_OBSERVATION']._serialized_end=125
|
|
||||||
_globals['_ACTIONS']._serialized_start=127
|
|
||||||
_globals['_ACTIONS']._serialized_end=150
|
|
||||||
_globals['_POLICYSETUP']._serialized_start=152
|
|
||||||
_globals['_POLICYSETUP']._serialized_end=179
|
|
||||||
_globals['_EMPTY']._serialized_start=181
|
|
||||||
_globals['_EMPTY']._serialized_end=188
|
|
||||||
_globals['_ASYNCINFERENCE']._serialized_start=289
|
|
||||||
_globals['_ASYNCINFERENCE']._serialized_end=638
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
|
||||||
@@ -1,277 +0,0 @@
|
|||||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
||||||
"""Client and server classes corresponding to protobuf-defined services."""
|
|
||||||
import grpc
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from lerobot.transport import async_inference_pb2 as async__inference__pb2
|
|
||||||
|
|
||||||
GRPC_GENERATED_VERSION = '1.71.0'
|
|
||||||
GRPC_VERSION = grpc.__version__
|
|
||||||
_version_not_supported = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from grpc._utilities import first_version_is_lower
|
|
||||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
|
||||||
except ImportError:
|
|
||||||
_version_not_supported = True
|
|
||||||
|
|
||||||
if _version_not_supported:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
|
||||||
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
|
|
||||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
|
||||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
|
||||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncInferenceStub:
|
|
||||||
"""AsyncInference: from Robot perspective
|
|
||||||
Robot send observations to & executes action received from a remote Policy server
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, channel):
|
|
||||||
"""Constructor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
channel: A grpc.Channel.
|
|
||||||
"""
|
|
||||||
self.SendObservations = channel.stream_unary(
|
|
||||||
'/async_inference.AsyncInference/SendObservations',
|
|
||||||
request_serializer=async__inference__pb2.Observation.SerializeToString,
|
|
||||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
_registered_method=True)
|
|
||||||
self.GetActions = channel.unary_unary(
|
|
||||||
'/async_inference.AsyncInference/GetActions',
|
|
||||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
response_deserializer=async__inference__pb2.Actions.FromString,
|
|
||||||
_registered_method=True)
|
|
||||||
self.SendPolicyInstructions = channel.unary_unary(
|
|
||||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
|
||||||
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
|
|
||||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
_registered_method=True)
|
|
||||||
self.Ready = channel.unary_unary(
|
|
||||||
'/async_inference.AsyncInference/Ready',
|
|
||||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
_registered_method=True)
|
|
||||||
self.Stop = channel.unary_unary(
|
|
||||||
'/async_inference.AsyncInference/Stop',
|
|
||||||
request_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
response_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
_registered_method=True)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncInferenceServicer:
|
|
||||||
"""AsyncInference: from Robot perspective
|
|
||||||
Robot send observations to & executes action received from a remote Policy server
|
|
||||||
"""
|
|
||||||
|
|
||||||
def SendObservations(self, request_iterator, context):
|
|
||||||
"""Robot -> Policy to share observations with a remote inference server
|
|
||||||
Policy -> Robot to share actions predicted for given observations
|
|
||||||
"""
|
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
||||||
context.set_details('Method not implemented!')
|
|
||||||
raise NotImplementedError('Method not implemented!')
|
|
||||||
|
|
||||||
def GetActions(self, request, context):
|
|
||||||
"""Missing associated documentation comment in .proto file."""
|
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
||||||
context.set_details('Method not implemented!')
|
|
||||||
raise NotImplementedError('Method not implemented!')
|
|
||||||
|
|
||||||
def SendPolicyInstructions(self, request, context):
|
|
||||||
"""Missing associated documentation comment in .proto file."""
|
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
||||||
context.set_details('Method not implemented!')
|
|
||||||
raise NotImplementedError('Method not implemented!')
|
|
||||||
|
|
||||||
def Ready(self, request, context):
|
|
||||||
"""Missing associated documentation comment in .proto file."""
|
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
||||||
context.set_details('Method not implemented!')
|
|
||||||
raise NotImplementedError('Method not implemented!')
|
|
||||||
|
|
||||||
def Stop(self, request, context):
|
|
||||||
"""Missing associated documentation comment in .proto file."""
|
|
||||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
||||||
context.set_details('Method not implemented!')
|
|
||||||
raise NotImplementedError('Method not implemented!')
|
|
||||||
|
|
||||||
|
|
||||||
def add_AsyncInferenceServicer_to_server(servicer, server):
|
|
||||||
rpc_method_handlers = {
|
|
||||||
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
|
||||||
servicer.SendObservations,
|
|
||||||
request_deserializer=async__inference__pb2.Observation.FromString,
|
|
||||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
),
|
|
||||||
'GetActions': grpc.unary_unary_rpc_method_handler(
|
|
||||||
servicer.GetActions,
|
|
||||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
response_serializer=async__inference__pb2.Actions.SerializeToString,
|
|
||||||
),
|
|
||||||
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
|
||||||
servicer.SendPolicyInstructions,
|
|
||||||
request_deserializer=async__inference__pb2.PolicySetup.FromString,
|
|
||||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
),
|
|
||||||
'Ready': grpc.unary_unary_rpc_method_handler(
|
|
||||||
servicer.Ready,
|
|
||||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
),
|
|
||||||
'Stop': grpc.unary_unary_rpc_method_handler(
|
|
||||||
servicer.Stop,
|
|
||||||
request_deserializer=async__inference__pb2.Empty.FromString,
|
|
||||||
response_serializer=async__inference__pb2.Empty.SerializeToString,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
generic_handler = grpc.method_handlers_generic_handler(
|
|
||||||
'async_inference.AsyncInference', rpc_method_handlers)
|
|
||||||
server.add_generic_rpc_handlers((generic_handler,))
|
|
||||||
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
|
|
||||||
|
|
||||||
|
|
||||||
# This class is part of an EXPERIMENTAL API.
|
|
||||||
class AsyncInference:
|
|
||||||
"""AsyncInference: from Robot perspective
|
|
||||||
Robot send observations to & executes action received from a remote Policy server
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def SendObservations(request_iterator,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.stream_unary(
|
|
||||||
request_iterator,
|
|
||||||
target,
|
|
||||||
'/async_inference.AsyncInference/SendObservations',
|
|
||||||
async__inference__pb2.Observation.SerializeToString,
|
|
||||||
async__inference__pb2.Empty.FromString,
|
|
||||||
options,
|
|
||||||
channel_credentials,
|
|
||||||
insecure,
|
|
||||||
call_credentials,
|
|
||||||
compression,
|
|
||||||
wait_for_ready,
|
|
||||||
timeout,
|
|
||||||
metadata,
|
|
||||||
_registered_method=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def GetActions(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(
|
|
||||||
request,
|
|
||||||
target,
|
|
||||||
'/async_inference.AsyncInference/GetActions',
|
|
||||||
async__inference__pb2.Empty.SerializeToString,
|
|
||||||
async__inference__pb2.Actions.FromString,
|
|
||||||
options,
|
|
||||||
channel_credentials,
|
|
||||||
insecure,
|
|
||||||
call_credentials,
|
|
||||||
compression,
|
|
||||||
wait_for_ready,
|
|
||||||
timeout,
|
|
||||||
metadata,
|
|
||||||
_registered_method=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def SendPolicyInstructions(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(
|
|
||||||
request,
|
|
||||||
target,
|
|
||||||
'/async_inference.AsyncInference/SendPolicyInstructions',
|
|
||||||
async__inference__pb2.PolicySetup.SerializeToString,
|
|
||||||
async__inference__pb2.Empty.FromString,
|
|
||||||
options,
|
|
||||||
channel_credentials,
|
|
||||||
insecure,
|
|
||||||
call_credentials,
|
|
||||||
compression,
|
|
||||||
wait_for_ready,
|
|
||||||
timeout,
|
|
||||||
metadata,
|
|
||||||
_registered_method=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def Ready(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(
|
|
||||||
request,
|
|
||||||
target,
|
|
||||||
'/async_inference.AsyncInference/Ready',
|
|
||||||
async__inference__pb2.Empty.SerializeToString,
|
|
||||||
async__inference__pb2.Empty.FromString,
|
|
||||||
options,
|
|
||||||
channel_credentials,
|
|
||||||
insecure,
|
|
||||||
call_credentials,
|
|
||||||
compression,
|
|
||||||
wait_for_ready,
|
|
||||||
timeout,
|
|
||||||
metadata,
|
|
||||||
_registered_method=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def Stop(request,
|
|
||||||
target,
|
|
||||||
options=(),
|
|
||||||
channel_credentials=None,
|
|
||||||
call_credentials=None,
|
|
||||||
insecure=False,
|
|
||||||
compression=None,
|
|
||||||
wait_for_ready=None,
|
|
||||||
timeout=None,
|
|
||||||
metadata=None):
|
|
||||||
return grpc.experimental.unary_unary(
|
|
||||||
request,
|
|
||||||
target,
|
|
||||||
'/async_inference.AsyncInference/Stop',
|
|
||||||
async__inference__pb2.Empty.SerializeToString,
|
|
||||||
async__inference__pb2.Empty.FromString,
|
|
||||||
options,
|
|
||||||
channel_credentials,
|
|
||||||
insecure,
|
|
||||||
call_credentials,
|
|
||||||
compression,
|
|
||||||
wait_for_ready,
|
|
||||||
timeout,
|
|
||||||
metadata,
|
|
||||||
_registered_method=True)
|
|
||||||
@@ -33,6 +33,17 @@ service LearnerService {
|
|||||||
rpc Ready(Empty) returns (Empty);
|
rpc Ready(Empty) returns (Empty);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AsyncInference: from Robot perspective
|
||||||
|
// Robot send observations to & executes action received from a remote Policy server
|
||||||
|
service AsyncInference {
|
||||||
|
// Robot -> Policy to share observations with a remote inference server
|
||||||
|
// Policy -> Robot to share actions predicted for given observations
|
||||||
|
rpc SendObservations(stream Observation) returns (Empty);
|
||||||
|
rpc GetActions(Empty) returns (Actions);
|
||||||
|
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
|
||||||
|
rpc Ready(Empty) returns (Empty);
|
||||||
|
}
|
||||||
|
|
||||||
enum TransferState {
|
enum TransferState {
|
||||||
TRANSFER_UNKNOWN = 0;
|
TRANSFER_UNKNOWN = 0;
|
||||||
TRANSFER_BEGIN = 1;
|
TRANSFER_BEGIN = 1;
|
||||||
@@ -56,4 +67,21 @@ message InteractionMessage {
|
|||||||
bytes data = 2;
|
bytes data = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Messages
|
||||||
|
message Observation {
|
||||||
|
// sent by Robot, to remote Policy
|
||||||
|
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
|
||||||
|
bytes data = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Actions {
|
||||||
|
// sent by remote Policy, to Robot
|
||||||
|
bytes data = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message PolicySetup {
|
||||||
|
// sent by Robot to remote server, to init Policy
|
||||||
|
bytes data = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message Empty {}
|
message Empty {}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# NO CHECKED-IN PROTOBUF GENCODE
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
# source: lerobot/transport/services.proto
|
# source: lerobot/transport/services.proto
|
||||||
# Protobuf Python Version: 5.29.0
|
# Protobuf Python Version: 6.31.0
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
@@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database
|
|||||||
from google.protobuf.internal import builder as _builder
|
from google.protobuf.internal import builder as _builder
|
||||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
_runtime_version.Domain.PUBLIC,
|
_runtime_version.Domain.PUBLIC,
|
||||||
5,
|
6,
|
||||||
29,
|
31,
|
||||||
0,
|
0,
|
||||||
'',
|
'',
|
||||||
'lerobot/transport/services.proto'
|
'lerobot/transport/services.proto'
|
||||||
@@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default()
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3')
|
||||||
|
|
||||||
_globals = globals()
|
_globals = globals()
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals)
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
DESCRIPTOR._loaded_options = None
|
DESCRIPTOR._loaded_options = None
|
||||||
_globals['_TRANSFERSTATE']._serialized_start=298
|
_globals['_TRANSFERSTATE']._serialized_start=431
|
||||||
_globals['_TRANSFERSTATE']._serialized_end=394
|
_globals['_TRANSFERSTATE']._serialized_end=527
|
||||||
_globals['_TRANSITION']._serialized_start=47
|
_globals['_TRANSITION']._serialized_start=47
|
||||||
_globals['_TRANSITION']._serialized_end=123
|
_globals['_TRANSITION']._serialized_end=123
|
||||||
_globals['_PARAMETERS']._serialized_start=125
|
_globals['_PARAMETERS']._serialized_start=125
|
||||||
_globals['_PARAMETERS']._serialized_end=201
|
_globals['_PARAMETERS']._serialized_end=201
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_start=203
|
_globals['_INTERACTIONMESSAGE']._serialized_start=203
|
||||||
_globals['_INTERACTIONMESSAGE']._serialized_end=287
|
_globals['_INTERACTIONMESSAGE']._serialized_end=287
|
||||||
_globals['_EMPTY']._serialized_start=289
|
_globals['_OBSERVATION']._serialized_start=289
|
||||||
_globals['_EMPTY']._serialized_end=296
|
_globals['_OBSERVATION']._serialized_end=366
|
||||||
_globals['_LEARNERSERVICE']._serialized_start=397
|
_globals['_ACTIONS']._serialized_start=368
|
||||||
_globals['_LEARNERSERVICE']._serialized_end=654
|
_globals['_ACTIONS']._serialized_end=391
|
||||||
|
_globals['_POLICYSETUP']._serialized_start=393
|
||||||
|
_globals['_POLICYSETUP']._serialized_end=420
|
||||||
|
_globals['_EMPTY']._serialized_start=422
|
||||||
|
_globals['_EMPTY']._serialized_end=429
|
||||||
|
_globals['_LEARNERSERVICE']._serialized_start=530
|
||||||
|
_globals['_LEARNERSERVICE']._serialized_end=787
|
||||||
|
_globals['_ASYNCINFERENCE']._serialized_start=790
|
||||||
|
_globals['_ASYNCINFERENCE']._serialized_end=1035
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import warnings
|
|||||||
|
|
||||||
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
|
from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2
|
||||||
|
|
||||||
GRPC_GENERATED_VERSION = '1.71.0'
|
GRPC_GENERATED_VERSION = '1.73.1'
|
||||||
GRPC_VERSION = grpc.__version__
|
GRPC_VERSION = grpc.__version__
|
||||||
_version_not_supported = False
|
_version_not_supported = False
|
||||||
|
|
||||||
@@ -231,3 +231,212 @@ class LearnerService:
|
|||||||
timeout,
|
timeout,
|
||||||
metadata,
|
metadata,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncInferenceStub:
|
||||||
|
"""AsyncInference: from Robot perspective
|
||||||
|
Robot send observations to & executes action received from a remote Policy server
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.SendObservations = channel.stream_unary(
|
||||||
|
'/transport.AsyncInference/SendObservations',
|
||||||
|
request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||||
|
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.GetActions = channel.unary_unary(
|
||||||
|
'/transport.AsyncInference/GetActions',
|
||||||
|
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.SendPolicyInstructions = channel.unary_unary(
|
||||||
|
'/transport.AsyncInference/SendPolicyInstructions',
|
||||||
|
request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||||
|
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
self.Ready = channel.unary_unary(
|
||||||
|
'/transport.AsyncInference/Ready',
|
||||||
|
request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncInferenceServicer:
|
||||||
|
"""AsyncInference: from Robot perspective
|
||||||
|
Robot send observations to & executes action received from a remote Policy server
|
||||||
|
"""
|
||||||
|
|
||||||
|
def SendObservations(self, request_iterator, context):
|
||||||
|
"""Robot -> Policy to share observations with a remote inference server
|
||||||
|
Policy -> Robot to share actions predicted for given observations
|
||||||
|
"""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def GetActions(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def SendPolicyInstructions(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Ready(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_AsyncInferenceServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'SendObservations': grpc.stream_unary_rpc_method_handler(
|
||||||
|
servicer.SendObservations,
|
||||||
|
request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString,
|
||||||
|
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
'GetActions': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.GetActions,
|
||||||
|
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString,
|
||||||
|
),
|
||||||
|
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.SendPolicyInstructions,
|
||||||
|
request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString,
|
||||||
|
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
'Ready': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Ready,
|
||||||
|
request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'transport.AsyncInference', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class AsyncInference:
|
||||||
|
"""AsyncInference: from Robot perspective
|
||||||
|
Robot send observations to & executes action received from a remote Policy server
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def SendObservations(request_iterator,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.stream_unary(
|
||||||
|
request_iterator,
|
||||||
|
target,
|
||||||
|
'/transport.AsyncInference/SendObservations',
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString,
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def GetActions(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/transport.AsyncInference/GetActions',
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Actions.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def SendPolicyInstructions(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/transport.AsyncInference/SendPolicyInstructions',
|
||||||
|
lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString,
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Ready(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/transport.AsyncInference/Ready',
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString,
|
||||||
|
lerobot_dot_transport_dot_services__pb2.Empty.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ import io
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle # nosec B403: Safe usage for internal serialization only
|
import pickle # nosec B403: Safe usage for internal serialization only
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event
|
||||||
|
from queue import Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = ""
|
|||||||
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
||||||
|
|
||||||
|
|
||||||
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore
|
def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""):
|
||||||
bytes_buffer = io.BytesIO()
|
bytes_buffer = io.BytesIO()
|
||||||
step = 0
|
step = 0
|
||||||
|
|
||||||
@@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p
|
|||||||
bytes_buffer.write(item.data)
|
bytes_buffer.write(item.data)
|
||||||
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
||||||
|
|
||||||
queue.put(bytes_buffer.getvalue())
|
if queue is not None:
|
||||||
|
queue.put(bytes_buffer.getvalue())
|
||||||
|
else:
|
||||||
|
return bytes_buffer.getvalue()
|
||||||
|
|
||||||
bytes_buffer.seek(0)
|
bytes_buffer.seek(0)
|
||||||
bytes_buffer.truncate(0)
|
bytes_buffer.truncate(0)
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch):
|
|||||||
from lerobot.scripts.server.policy_server import PolicyServer
|
from lerobot.scripts.server.policy_server import PolicyServer
|
||||||
from lerobot.scripts.server.robot_client import RobotClient
|
from lerobot.scripts.server.robot_client import RobotClient
|
||||||
from lerobot.transport import (
|
from lerobot.transport import (
|
||||||
async_inference_pb2, # type: ignore
|
services_pb2, # type: ignore
|
||||||
async_inference_pb2_grpc, # type: ignore
|
services_pb2_grpc, # type: ignore
|
||||||
)
|
)
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
|
|
||||||
@@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch):
|
|||||||
|
|
||||||
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
||||||
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
||||||
return async_inference_pb2.Empty()
|
return services_pb2.Empty()
|
||||||
|
|
||||||
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
||||||
|
|
||||||
# Build gRPC server running a PolicyServer
|
# Build gRPC server running a PolicyServer
|
||||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
||||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||||
|
|
||||||
# Use the host/port specified in the fixture's config
|
# Use the host/port specified in the fixture's config
|
||||||
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
||||||
|
|||||||
Reference in New Issue
Block a user