Add Async Inference (#1196)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Francesco Capuano
2025-07-10 10:39:11 +02:00
committed by GitHub
parent ce2b9724bf
commit 30c161006d
15 changed files with 3266 additions and 1 deletion

View File

@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Callable
import torch
from lerobot.robots.config import RobotConfig
from lerobot.scripts.server.constants import (
DEFAULT_FPS,
DEFAULT_INFERENCE_LATENCY,
DEFAULT_OBS_QUEUE_TIMEOUT,
)
# Registry of chunk-aggregation strategies selectable from the CLI.
# Each callable blends the previously queued action (`old`) with the freshly
# predicted one (`new`) and returns the tensor to execute.
AGGREGATE_FUNCTIONS = {
    "weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
    "latest_only": lambda old, new: new,
    "average": lambda old, new: 0.5 * old + 0.5 * new,
    "conservative": lambda old, new: 0.7 * old + 0.3 * new,
}


def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Look up an aggregation function by its registry name.

    Args:
        name: Key into ``AGGREGATE_FUNCTIONS``.

    Returns:
        The aggregation callable mapping (old, new) action tensors to a blended tensor.

    Raises:
        ValueError: If ``name`` is not a registered aggregate function.
    """
    if name in AGGREGATE_FUNCTIONS:
        return AGGREGATE_FUNCTIONS[name]
    raise ValueError(f"Unknown aggregate function '{name}'. Available: {list(AGGREGATE_FUNCTIONS.keys())}")
@dataclass
class PolicyServerConfig:
    """Configuration for PolicyServer.

    This class defines all configurable parameters for the PolicyServer,
    including networking settings and action chunking specifications.
    """

    # Networking configuration
    host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
    port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})

    # Timing configuration
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
    inference_latency: float = field(
        default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
    )
    obs_queue_timeout: float = field(
        default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
    )

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: If any field is outside its valid range.
        """
        if self.port < 1 or self.port > 65535:
            raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
        # Fix: validating `self.environment_dt <= 0` raised ZeroDivisionError when
        # fps == 0; validate fps directly instead (environment_dt = 1 / fps is then
        # guaranteed positive too). Mirrors RobotClientConfig's fps check.
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")
        if self.inference_latency < 0:
            raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
        if self.obs_queue_timeout < 0:
            raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
        """Create a PolicyServerConfig from a dictionary.

        Fix: `to_dict` emits the derived "environment_dt" entry, which is a
        read-only property rather than an init field; it is dropped here so that
        `from_dict(cfg.to_dict())` round-trips instead of raising TypeError.
        """
        config_dict = {k: v for k, v in config_dict.items() if k != "environment_dt"}
        return cls(**config_dict)

    @property
    def environment_dt(self) -> float:
        """Environment time step, in seconds (derived from `fps`)."""
        return 1 / self.fps

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary."""
        return {
            "host": self.host,
            "port": self.port,
            "fps": self.fps,
            "environment_dt": self.environment_dt,
            "inference_latency": self.inference_latency,
            # Fix: previously omitted, so the dict did not capture the full config.
            "obs_queue_timeout": self.obs_queue_timeout,
        }
@dataclass
class RobotClientConfig:
    """Configuration for RobotClient.

    This class defines all configurable parameters for the RobotClient,
    including network connection, policy settings, and control behavior.
    """

    # Policy configuration
    policy_type: str = field(metadata={"help": "Type of policy to use"})
    pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})

    # Robot configuration (for CLI usage - robot instance will be created from this)
    robot: RobotConfig = field(metadata={"help": "Robot configuration"})

    # Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
    # would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
    actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})

    # Task instruction for the robot to execute (e.g., 'fold my tshirt')
    task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})

    # Network configuration
    server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})

    # Device configuration
    policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})

    # Control behavior configuration
    chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})

    # Aggregate function configuration (CLI-compatible)
    aggregate_fn_name: str = field(
        default="weighted_average",
        metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
    )

    # Debug configuration
    debug_visualize_queue_size: bool = field(
        default=False, metadata={"help": "Visualize the action queue size"}
    )

    # Verification configuration
    verify_robot_cameras: bool = field(
        default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
    )

    @property
    def environment_dt(self) -> float:
        """Environment time step, in seconds (derived from `fps`)."""
        return 1 / self.fps

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            ValueError: If any field is empty or outside its valid range.
        """
        if not self.server_address:
            raise ValueError("server_address cannot be empty")
        if not self.policy_type:
            raise ValueError("policy_type cannot be empty")
        if not self.pretrained_name_or_path:
            raise ValueError("pretrained_name_or_path cannot be empty")
        if not self.policy_device:
            raise ValueError("policy_device cannot be empty")
        if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
            raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")
        if self.actions_per_chunk <= 0:
            raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")

        # Resolve the aggregation callable once; raises ValueError for unknown names.
        self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
        """Create a RobotClientConfig from a dictionary."""
        return cls(**config_dict)

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary.

        NOTE(review): `robot` is not serialized here (presumably because RobotConfig
        is not directly dict-serializable — confirm), so `from_dict(cfg.to_dict())`
        requires re-adding a `robot` entry.
        """
        return {
            "server_address": self.server_address,
            "policy_type": self.policy_type,
            "pretrained_name_or_path": self.pretrained_name_or_path,
            "policy_device": self.policy_device,
            "chunk_size_threshold": self.chunk_size_threshold,
            "fps": self.fps,
            "actions_per_chunk": self.actions_per_chunk,
            "task": self.task,
            "debug_visualize_queue_size": self.debug_visualize_queue_size,
            "aggregate_fn_name": self.aggregate_fn_name,
            # Fix: previously omitted, breaking round-tripping of this flag.
            "verify_robot_cameras": self.verify_robot_cameras,
        }

View File

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Shared constants for the async-inference client and server.
# NOTE: these were previously bare triple-quoted strings placed *before* each
# assignment; such strings are no-op statements (the first one even became the
# module docstring), so they are converted to conventional comments here.

# Client side: the environment evolves with a time resolution equal to 1/fps.
DEFAULT_FPS = 30

# Server side: running inference on (at most) 1/fps.
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS

# Server side: timeout for the observation queue, in seconds.
DEFAULT_OBS_QUEUE_TIMEOUT = 2

# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]

# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]

View File

@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Event
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging
# Type aliases used throughout the async-inference client/server code.

# A single action vector.
Action = torch.Tensor
# A stacked sequence of actions predicted in one inference call.
ActionChunk = torch.Tensor

# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
# observation, ready for policy inference (image keys resized)
Observation = dict[str, torch.Tensor]
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
    """Plot the action-queue size recorded at every environment step.

    Args:
        action_queue_size: Queue-size samples, one per environment step.
            An empty list is a no-op (nothing to plot).
    """
    # Imported lazily so matplotlib is only required when visualization is requested.
    import matplotlib.pyplot as plt

    # Fix: max() below raises ValueError on an empty sequence; nothing to plot anyway.
    if not action_queue_size:
        return

    fig, ax = plt.subplots()
    ax.set_title("Action Queue Size Over Time")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Action Queue Size")
    ax.set_ylim(0, max(action_queue_size) * 1.1)
    ax.grid(True, alpha=0.3)
    ax.plot(range(len(action_queue_size)), action_queue_size)
    plt.show()
def validate_robot_cameras_for_policy(
    lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
) -> None:
    """Check that the robot's camera keys exactly match the policy's image features.

    Args:
        lerobot_observation_features: LeRobot-style observation feature specs of the robot.
        policy_image_features: Image features expected by the policy.

    Raises:
        ValueError: If the sets of image keys differ.
    """
    image_keys = list(filter(is_image_key, lerobot_observation_features))
    # Fix: was an `assert`, which is stripped under `python -O`; this is input
    # validation that must always run, so raise explicitly.
    if set(image_keys) != set(policy_image_features.keys()):
        raise ValueError(
            f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
        )
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
    """Map the robot's hardware observation keys to LeRobot dataset feature specs.

    Thin wrapper around `hw_to_dataset_features`, using the "observation" prefix
    and disabling video features.
    """
    return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
def is_image_key(k: str) -> bool:
    """Return True if `k` is a camera/image observation key (starts with OBS_IMAGES)."""
    return k.startswith(OBS_IMAGES)
def resize_robot_observation_image(image: torch.Tensor, resize_dims: tuple[int, int, int]) -> torch.Tensor:
    """Resize a single robot camera frame to the policy's expected resolution.

    Args:
        image: Frame in (H, W, C) channel-last layout, as produced by the robot observation.
        resize_dims: Target (C, H, W) policy feature shape; only H and W are used.

    Returns:
        The resized frame in (C, H, W) channel-first layout.
    """
    # Fix: the assert message previously claimed "(C, H, W)" although the code
    # below permutes from channel-last, i.e. expects (H, W, C) input.
    assert image.ndim == 3, f"Image must be (H, W, C)! Received {image.shape}"
    # (H, W, C) -> (C, H, W) for resizing from robot observation resolution to policy image resolution
    image = image.permute(2, 0, 1)
    dims = (resize_dims[1], resize_dims[2])
    # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
    image_batched = image.unsqueeze(0)
    # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
    resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
    return resized.squeeze(0)
def raw_observation_to_observation(
    raw_observation: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
    device: str,
) -> Observation:
    """Turn a raw robot observation into a policy-ready one.

    Remaps keys and resizes images via `prepare_raw_observation`, converts images
    to batched float tensors, and moves all tensors to `device`.

    Args:
        raw_observation: Observation dict straight from `robot.get_observation()`.
        lerobot_features: LeRobot dataset feature specs used for key remapping.
        policy_image_features: Image features (with target shapes) expected by the policy.
        device: Torch device string to move tensors to.
    """
    # Fix: removed a dead `observation = {}` that was immediately overwritten.
    observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
    for k, v in observation.items():
        if isinstance(v, torch.Tensor):  # VLAs present natural-language instructions in observations
            if "image" in k:
                # Policy expects images in shape (B, C, H, W)
                observation[k] = prepare_image(v).unsqueeze(0).to(device)
            else:
                observation[k] = v.to(device)
        else:
            observation[k] = v

    return observation
def prepare_image(image: torch.Tensor) -> torch.Tensor:
    """Minimal preprocessing: scale an int [0, 255] image to float32 in [0, 1]
    and return it as a memory-contiguous tensor."""
    scaled = image.type(torch.float32) / 255
    return scaled.contiguous()
def extract_state_from_raw_observation(
    lerobot_obs: RawObservation,
) -> torch.Tensor:
    """Pull the proprioceptive state out of an observation, ensuring a leading batch dim."""
    state = torch.tensor(lerobot_obs[OBS_STATE])
    # A flat (state_dim,) vector becomes (1, state_dim); anything else passes through.
    return state.unsqueeze(0) if state.ndim == 1 else state
def extract_images_from_raw_observation(
    lerobot_obs: RawObservation,
    camera_key: str,
) -> torch.Tensor:
    """Extract the image for a single camera key from a raw observation as a tensor."""
    # NOTE: return annotation corrected — a single tensor is returned, not a dict.
    return torch.tensor(lerobot_obs[camera_key])
def make_lerobot_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
) -> LeRobotObservation:
    """Make a lerobot observation from a raw observation.

    Remaps robot hardware keys (per-motor positions, camera names) to LeRobot
    dataset keys via `build_dataset_frame` with the "observation" prefix.
    """
    return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
def prepare_raw_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
) -> Observation:
    """Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
    policy_image_features).

    Returns a dict with the (batched) state, resized images in (C, H, W), and the
    optional natural-language "task" entry.
    """
    # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
    # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
    lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)

    # 2. Greps all observation.images.<> keys
    image_keys = list(filter(is_image_key, lerobot_obs))

    # state's shape is expected as (B, state_dim)
    state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}

    # Turns the image features to (C, H, W) with H, W matching the policy image features.
    # This reduces the resolution of the images.
    # Fix: removed a dead first computation of `image_dict` (via
    # extract_images_from_raw_observation) that was immediately overwritten below.
    image_dict = {
        key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
        for key in image_keys
    }

    # Forward the natural-language task instruction (used by VLA-style policies) when present.
    if "task" in robot_obs:
        state_dict["task"] = robot_obs["task"]

    return {**state_dict, **image_dict}
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
    """
    Get a logger using the standardized logging setup from utils.py.

    Args:
        name: Logger name (e.g., 'policy_server', 'robot_client')
        log_to_file: Whether to also log to a file

    Returns:
        Configured logger instance
    """
    log_file = None
    if log_to_file:
        # Logs live under ./logs, one timestamped file per process start.
        os.makedirs("logs", exist_ok=True)
        log_file = Path(f"logs/{name}_{int(time.time())}.log")

    # Initialize the standardized logging, then hand back a named logger.
    init_logging(log_file=log_file, display_pid=False)
    return logging.getLogger(name)
@dataclass
class TimedData:
    """Base wrapper pairing a piece of data with timing metadata.

    Attributes:
        timestamp: Unix timestamp relative to the data's creation.
        timestep: The discrete environment timestep of the data.
    """

    # Unix timestamp at creation time.
    timestamp: float
    # Environment step index this data corresponds to.
    timestep: int

    def get_timestamp(self):
        """Return the creation timestamp."""
        return self.timestamp

    def get_timestep(self):
        """Return the environment timestep."""
        return self.timestep
@dataclass
class TimedAction(TimedData):
    """A single action paired with the timestamp/timestep it should be executed at."""

    # Action tensor to execute at `timestep`.
    action: Action

    def get_action(self):
        """Return the wrapped action tensor."""
        return self.action
@dataclass
class TimedObservation(TimedData):
    """A raw robot observation paired with its capture timestamp/timestep."""

    # Raw observation dict as produced by the robot.
    observation: RawObservation
    # When True, the server processes this observation unconditionally,
    # bypassing the duplicate/similarity filtering in _enqueue_observation.
    must_go: bool = False

    def get_observation(self):
        """Return the wrapped raw observation."""
        return self.observation
@dataclass
class FPSTracker:
    """Utility class to track FPS metrics over time.

    Attributes:
        target_fps: The rate observations are expected to arrive at.
        first_timestamp: Timestamp of the first tracked observation (None until then).
        total_obs_count: Number of observations tracked since the last reset.
    """

    target_fps: float
    # Fix: was annotated plain `float` although the default (and reset value) is None.
    first_timestamp: "float | None" = None
    total_obs_count: int = 0

    def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
        """Record one observation and return average FPS since start vs the target."""
        self.total_obs_count += 1

        # Initialize first observation time
        if self.first_timestamp is None:
            self.first_timestamp = current_timestamp

        # Calculate overall average FPS (since start); N observations span N-1 intervals.
        total_duration = current_timestamp - self.first_timestamp
        avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0

        return {"avg_fps": avg_fps, "target_fps": self.target_fps}

    def reset(self):
        """Reset the FPS tracker state"""
        self.first_timestamp = None
        self.total_obs_count = 0
@dataclass
class RemotePolicyConfig:
    """Specification of the policy a robot client asks the server to load.

    Sent (pickled) by the client in `SendPolicyInstructions`.
    """

    # Policy family identifier (e.g. "act"); validated server-side against SUPPORTED_POLICIES.
    policy_type: str
    # Hub model id or local path passed to `from_pretrained`.
    pretrained_name_or_path: str
    # LeRobot dataset feature specs describing the client's robot observations.
    lerobot_features: dict[str, PolicyFeature]
    # How many actions of each predicted chunk the server should return.
    actions_per_chunk: int
    # Torch device the server runs inference on.
    device: str = "cpu"
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
"""Check if two observation states are similar, under a tolerance threshold"""
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
def observations_similar(
    obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
    """Check if two observations are similar, under a tolerance threshold. Measures distance between
    observations as the difference in joint-space between the two observations.

    NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
    An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
    to surpass this joint-space similarity check.
    """
    # Project both observations into joint space, then compare the state vectors.
    state_1, state_2 = (
        extract_state_from_raw_observation(make_lerobot_observation(o.get_observation(), lerobot_features))
        for o in (obs1, obs2)
    )
    return _compare_observation_states(state_1, state_2, atol=atol)
def send_bytes_in_chunks(
    buffer: bytes,
    message_class: Any,
    log_prefix: str = "",
    silent: bool = True,
    chunk_size: int = 3 * 1024 * 1024,
):
    """Yield `message_class` messages streaming `buffer` in `chunk_size` pieces.

    Each message is tagged with a transfer state: BEGIN for the first chunk,
    END for the last (a payload that fits in one chunk is tagged END), and
    MIDDLE for everything in between.
    """
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
    # chunk size as I am using it to send image observations.
    stream = io.BytesIO(buffer)
    size_in_bytes = bytes_buffer_size(stream)

    logging_method = logging.debug if silent else logging.info
    logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")

    sent_bytes = 0
    while sent_bytes < size_in_bytes:
        # Tag the chunk with its position in the stream (END wins for single-chunk payloads).
        if sent_bytes + chunk_size >= size_in_bytes:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
        else:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE

        chunk = stream.read(min(chunk_size, size_in_bytes - sent_bytes))
        yield message_class(transfer_state=transfer_state, data=chunk)

        sent_bytes += len(chunk)
        logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")

    logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
    iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
):  # type: ignore
    """Reassemble a chunked byte stream produced by `send_bytes_in_chunks`.

    Consumes `iterator` until a TRANSFER_END message arrives, then returns the
    complete payload. Returns None early when `continue_receiving` is cleared.

    Raises:
        ValueError: On an unknown transfer state.
    """
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving
    # is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown)
    payload = io.BytesIO()
    chunk_idx = 0
    states = async_inference_pb2.TransferState

    logger.info(f"{log_prefix} Starting receiver")
    for item in iterator:
        logger.debug(f"{log_prefix} Received item")
        if not continue_receiving.is_set():
            logger.info(f"{log_prefix} Shutting down receiver")
            return

        if item.transfer_state == states.TRANSFER_BEGIN:
            # First chunk: discard any partial data from a previous transfer.
            payload.seek(0)
            payload.truncate(0)
            payload.write(item.data)
            logger.debug(f"{log_prefix} Received data at step 0")
        elif item.transfer_state == states.TRANSFER_MIDDLE:
            payload.write(item.data)
            chunk_idx += 1
            logger.debug(f"{log_prefix} Received data at step {chunk_idx}")
        elif item.transfer_state == states.TRANSFER_END:
            # Last chunk: hand back the full payload and clear the buffer.
            payload.write(item.data)
            logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(payload)}")
            complete_bytes = payload.getvalue()
            payload.seek(0)
            payload.truncate(0)
            logger.debug(f"{log_prefix} Queue updated")
            return complete_bytes
        else:
            logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
            raise ValueError(f"Received unknown transfer state {item.transfer_state}")

View File

@@ -0,0 +1,403 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
```shell
python src/lerobot/scripts/server/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
--fps=30 \
--inference_latency=0.033 \
--obs_queue_timeout=1
```
"""
import logging
import pickle # nosec
import threading
import time
from concurrent import futures
from dataclasses import asdict
from pprint import pformat
from queue import Empty, Queue
import draccus
import grpc
import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
observations_similar,
raw_observation_to_observation,
receive_bytes_in_chunks,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig):
self.config = config
self._running_event = threading.Event()
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps)
self.observation_queue = Queue(maxsize=1)
self._predicted_timesteps_lock = threading.Lock()
self._predicted_timesteps = set()
self.last_processed_obs = None
# Attributes will be set by SendPolicyInstructions
self.device = None
self.policy_type = None
self.lerobot_features = None
self.actions_per_chunk = None
self.policy = None
@property
def running(self):
return self._running_event.is_set()
@property
def policy_image_features(self):
return self.policy.config.image_features
def _reset_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self._running_event.clear()
self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock:
self._predicted_timesteps = set()
def Ready(self, request, context): # noqa: N802
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._reset_server()
self._running_event.set()
return async_inference_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty()
client_id = context.peer()
policy_specs = pickle.loads(request.data) # nosec
if not isinstance(policy_specs, RemotePolicyConfig):
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
if policy_specs.policy_type not in SUPPORTED_POLICIES:
raise ValueError(
f"Policy type {policy_specs.policy_type} not supported. "
f"Supported policies: {SUPPORTED_POLICIES}"
)
self.logger.info(
f"Receiving policy instructions from {client_id} | "
f"Policy type: {policy_specs.policy_type} | "
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
f"Device: {policy_specs.device}"
)
self.device = policy_specs.device
self.policy_type = policy_specs.policy_type # act, pi0, etc.
self.lerobot_features = policy_specs.lerobot_features
self.actions_per_chunk = policy_specs.actions_per_chunk
policy_class = get_policy_class(self.policy_type)
start = time.perf_counter()
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
self.policy.to(self.device)
end = time.perf_counter()
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving observations from {client_id}")
receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger
) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
obs_timestep = timed_observation.get_timestep()
obs_timestamp = timed_observation.get_timestamp()
# Calculate FPS metrics
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
self.logger.info(
f"Received observation #{obs_timestep} | "
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
f"Target: {fps_metrics['target_fps']:.2f} | "
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
)
self.logger.debug(
f"Server timestamp: {receive_time:.6f} | "
f"Client timestamp: {obs_timestamp:.6f} | "
f"Deserialization time: {deserialize_time:.6f}s"
)
if not self._enqueue_observation(
timed_observation # wrapping a RawObservation
):
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty()
def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single
chunk, containing multiple actions."""
client_id = context.peer()
self.logger.debug(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
try:
getactions_starts = time.perf_counter()
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
self.logger.info(
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
)
with self._predicted_timesteps_lock:
self._predicted_timesteps.add(obs.get_timestep())
start_time = time.perf_counter()
action_chunk = self._predict_action_chunk(obs)
inference_time = time.perf_counter() - start_time
start_time = time.perf_counter()
actions_bytes = pickle.dumps(action_chunk) # nosec
serialize_time = time.perf_counter() - start_time
# Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | "
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
)
self.logger.debug(
f"Action chunk #{obs.get_timestep()} generated | "
f"Inference time: {inference_time:.2f}s |"
f"Serialize time: {serialize_time:.2f}s |"
f"Total time: {inference_time + serialize_time:.2f}s"
)
time.sleep(
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
) # sleep controls inference latency
return actions
except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty()
except Exception as e:
self.logger.error(f"Error in StreamActions: {e}")
return async_inference_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy"""
with self._predicted_timesteps_lock:
predicted_timesteps = self._predicted_timesteps
if obs.get_timestep() in predicted_timesteps:
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
return False
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
self.logger.debug(
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
)
return False
else:
return True
def _enqueue_observation(self, obs: TimedObservation) -> bool:
"""Enqueue an observation if it must go through processing, otherwise skip it.
Observations not in queue are never run through the policy network"""
if (
obs.must_go
or self.last_processed_obs is None
or self._obs_sanity_checks(obs, self.last_processed_obs)
):
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
self.logger.debug(
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
)
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
self.logger.debug("Observation queue was full, removed oldest observation")
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(obs)
return True
return False
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
for i, action in enumerate(action_chunk)
]
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
"""
Prepare observation, ready for policy inference.
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
client and then convert them to float32 [0,1] images here, before running inference.
"""
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
observation: Observation = raw_observation_to_observation(
observation_t.get_observation(),
self.lerobot_features,
self.policy_image_features,
self.device,
)
# processed Observation - right keys, right dtype, right image shape
return observation
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Get an action chunk from the policy. The chunk contains only"""
chunk = self.policy.predict_action_chunk(observation)
if chunk.ndim != 3:
chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim)
return chunk[:, : self.actions_per_chunk, :] + torch.randn_like(chunk[:, : self.actions_per_chunk, :])
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action chunk based on an observation"""
inference_starts = time.perf_counter()
"""1. Prepare observation"""
start_time = time.perf_counter()
observation = self._prepare_observation(observation_t)
preprocessing_time = time.perf_counter() - start_time
self.last_processed_obs: TimedObservation = observation_t
"""2. Get action chunk"""
start_time = time.perf_counter()
action_tensor = self._get_action_chunk(observation)
inference_time = time.perf_counter() - start_time
"""3. Post-inference processing"""
start_time = time.perf_counter()
# Move to CPU before serializing
action_tensor = action_tensor.cpu().squeeze(0)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
postprocessing_time = time.perf_counter() - start_time
inference_stops = time.perf_counter()
self.logger.info(
f"Observation {observation_t.get_timestep()} |"
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
)
# full-process latency breakdown for debugging purposes
self.logger.debug(
f"Observation {observation_t.get_timestep()} | "
f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
)
return action_chunk
def stop(self):
    """Stop the server by resetting its internal state and logging the shutdown."""
    # Clears server-side state (observation/action bookkeeping) before shutdown.
    self._reset_server()
    self.logger.info("Server stopping...")
@draccus.wrap()
def serve(cfg: PolicyServerConfig):
    """Start the PolicyServer with the given configuration.

    Args:
        cfg: PolicyServerConfig instance, parsed from the command line by draccus.
    """
    logging.info(pformat(asdict(cfg)))

    # Create the server instance first
    policy_server = PolicyServer(cfg)

    # Setup and start gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"{cfg.host}:{cfg.port}")

    policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
    server.start()
    # Blocks the main thread until the gRPC server is stopped.
    server.wait_for_termination()

    policy_server.logger.info("Server terminated")


if __name__ == "__main__":
    serve()

View File

@@ -0,0 +1,509 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example command:
```shell
python src/lerobot/scripts/server/robot_client.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--task="dummy" \
--server_address=127.0.0.1:8080 \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
--debug_visualize_queue_size=True
```
"""
import logging
import pickle # nosec
import threading
import time
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any, Callable, Optional
import draccus
import grpc
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs.policies import PreTrainedConfig
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
from lerobot.scripts.server.helpers import (
Action,
FPSTracker,
Observation,
RawObservation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class RobotClient:
    """Robot-side client for asynchronous inference.

    Streams observations to a remote policy server over gRPC and executes the
    action chunks received in return. After a barrier synchronization, two
    threads cooperate:
    - `receive_actions` polls the server and fills the local action queue;
    - `control_loop` pops actions, drives the robot, and sends fresh
      observations whenever the queue drains below a configurable threshold.
    """

    prefix = "robot_client"
    logger = get_logger(prefix)

    def __init__(self, config: RobotClientConfig):
        """Initialize RobotClient with unified configuration.

        Args:
            config: RobotClientConfig containing all configuration parameters
        """
        # Store configuration
        self.config = config
        self.robot = make_robot_from_config(config.robot)
        self.robot.connect()

        lerobot_features = map_robot_keys_to_lerobot_features(self.robot)

        if config.verify_robot_cameras:
            # Load policy config for validation
            policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
            policy_image_features = policy_config.image_features

            # The cameras specified for inference must match the one supported by the policy chosen
            validate_robot_cameras_for_policy(lerobot_features, policy_image_features)

        # Use environment variable if server_address is not provided in config
        self.server_address = config.server_address

        self.policy_config = RemotePolicyConfig(
            config.policy_type,
            config.pretrained_name_or_path,
            lerobot_features,
            config.actions_per_chunk,
            config.policy_device,
        )
        self.channel = grpc.insecure_channel(self.server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        self.logger.info(f"Initializing client to connect to server at {self.server_address}")

        self._running_event = threading.Event()

        # Initialize client side variables
        self.latest_action_lock = threading.Lock()
        self.latest_action = -1  # timestep of the last action performed on the robot
        self.action_chunk_size = -1  # largest chunk length seen so far; set on first receive

        self._chunk_size_threshold = config.chunk_size_threshold

        self.action_queue = Queue()
        self.action_queue_lock = threading.Lock()  # Protect queue operations
        self.action_queue_size = []  # history of queue sizes, for debugging/visualization

        self.start_barrier = threading.Barrier(2)  # 2 threads: action receiver, control loop

        # FPS measurement
        self.fps_tracker = FPSTracker(target_fps=self.config.fps)

        self.logger.info("Robot connected and ready")

        # Use an event for thread-safe coordination
        self.must_go = threading.Event()
        self.must_go.set()  # Initially set - observations qualify for direct processing

    @property
    def running(self):
        """True between a successful start() and stop()."""
        return self._running_event.is_set()

    def start(self):
        """Start the robot client and connect to the policy server.

        Returns:
            True if the handshake and policy setup succeeded, False otherwise.
        """
        try:
            # client-server handshake
            start_time = time.perf_counter()
            self.stub.Ready(async_inference_pb2.Empty())
            end_time = time.perf_counter()
            self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
            policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)

            self.logger.info("Sending policy instructions to policy server")
            self.logger.debug(
                f"Policy type: {self.policy_config.policy_type} | "
                f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
                f"Device: {self.policy_config.device}"
            )

            self.stub.SendPolicyInstructions(policy_setup)

            self._running_event.set()

            return True

        except grpc.RpcError as e:
            self.logger.error(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client: disconnect the robot and close the gRPC channel."""
        self._running_event.clear()

        self.robot.disconnect()
        self.logger.debug("Robot disconnected")

        self.channel.close()
        self.logger.debug("Client stopped, channel closed")

    def send_observation(
        self,
        obs: TimedObservation,
    ) -> bool:
        """Send observation to the policy server.
        Returns True if the observation was sent successfully, False otherwise."""
        if not self.running:
            raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")

        if not isinstance(obs, TimedObservation):
            raise ValueError("Input observation needs to be a TimedObservation!")

        start_time = time.perf_counter()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.perf_counter() - start_time
        self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")

        try:
            # Observations can exceed the gRPC message size, so stream in chunks.
            observation_iterator = send_bytes_in_chunks(
                observation_bytes,
                async_inference_pb2.Observation,
                log_prefix="[CLIENT] Observation",
                silent=True,
            )
            _ = self.stub.SendObservations(observation_iterator)

            obs_timestep = obs.get_timestep()
            self.logger.info(f"Sent observation #{obs_timestep} | ")

            return True

        except grpc.RpcError as e:
            self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

    def _inspect_action_queue(self):
        """Return (size, sorted timesteps) of the action queue. Debugging aid."""
        with self.action_queue_lock:
            queue_size = self.action_queue.qsize()
            timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
        self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")

        return queue_size, timestamps

    def _aggregate_action_queues(
        self,
        incoming_actions: list[TimedAction],
        aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Finds the same timestep actions in the queue and aggregates them using the aggregate_fn.

        The queue is rebuilt from `incoming_actions`: actions with a timestep
        at or before the latest performed action are dropped, timesteps present
        in both queues are merged via `aggregate_fn(old, new)`, and brand-new
        timesteps are kept as-is. Note queued timesteps absent from the
        incoming chunk are discarded with the old queue.
        """
        if aggregate_fn is None:
            # default aggregate function: take the latest action
            def aggregate_fn(x1, x2):
                return x2

        future_action_queue = Queue()
        with self.action_queue_lock:
            internal_queue = self.action_queue.queue
            current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}

        for new_action in incoming_actions:
            with self.latest_action_lock:
                latest_action = self.latest_action

            # New action is older than the latest action in the queue, skip it
            if new_action.get_timestep() <= latest_action:
                continue

            # If the new action's timestep is not in the current action queue, add it directly
            elif new_action.get_timestep() not in current_action_queue:
                future_action_queue.put(new_action)
                continue

            # If the new action's timestep is in the current action queue, aggregate it
            # TODO: There is probably a way to do this with broadcasting of the two action tensors
            future_action_queue.put(
                TimedAction(
                    timestamp=new_action.get_timestamp(),
                    timestep=new_action.get_timestep(),
                    action=aggregate_fn(
                        current_action_queue[new_action.get_timestep()], new_action.get_action()
                    ),
                )
            )

        with self.action_queue_lock:
            self.action_queue = future_action_queue

    def receive_actions(self, verbose: bool = False):
        """Receive actions from the policy server (runs in its own thread)."""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
                if len(actions_chunk.data) == 0:
                    continue  # received `Empty` from server, wait for next call

                receive_time = time.time()

                # Deserialize bytes back into list[TimedAction]
                deserialize_start = time.perf_counter()
                timed_actions = pickle.loads(actions_chunk.data)  # nosec
                deserialize_time = time.perf_counter() - deserialize_start

                self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))

                # Calculate network latency if we have matching observations
                if len(timed_actions) > 0 and verbose:
                    with self.latest_action_lock:
                        latest_action = self.latest_action
                    self.logger.debug(f"Current latest action: {latest_action}")

                    # Get queue state before changes
                    old_size, old_timesteps = self._inspect_action_queue()
                    if not old_timesteps:
                        old_timesteps = [latest_action]  # queue was empty

                    # Log incoming actions
                    incoming_timesteps = [a.get_timestep() for a in timed_actions]

                    first_action_timestep = timed_actions[0].get_timestep()
                    server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000

                    self.logger.info(
                        f"Received action chunk for step #{first_action_timestep} | "
                        f"Latest action: #{latest_action} | "
                        f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
                        f"Deserialization time: {deserialize_time * 1000:.2f}ms"
                    )

                # Update action queue
                start_time = time.perf_counter()
                self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
                queue_update_time = time.perf_counter() - start_time

                self.must_go.set()  # after receiving actions, next empty queue triggers must-go processing!

                # Guarded on len(timed_actions) too: the pre-update variables
                # (old_size, old_timesteps, incoming_timesteps) only exist when
                # the incoming chunk was non-empty.
                if verbose and len(timed_actions) > 0:
                    # Get queue state after changes
                    new_size, new_timesteps = self._inspect_action_queue()
                    if not new_timesteps:
                        new_timesteps = [latest_action]  # all incoming actions were stale

                    with self.latest_action_lock:
                        latest_action = self.latest_action

                    self.logger.info(
                        f"Latest action: {latest_action} | "
                        f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
                        f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
                    )
                    self.logger.debug(
                        f"Queue update complete ({queue_update_time:.6f}s) | "
                        f"Before: {old_size} items | "
                        f"After: {new_size} items | "
                    )

            except grpc.RpcError as e:
                self.logger.error(f"Error receiving actions: {e}")

    def actions_available(self):
        """Check if there are actions available in the queue"""
        with self.action_queue_lock:
            return not self.action_queue.empty()

    def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
        """Map a flat action tensor onto the robot's named action features."""
        action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
        return action

    def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
        """Reading and performing actions in local queue.

        Returns the action dict sent to the robot, or an empty dict if the
        queue turned out to be empty.
        """
        # Lock only for queue operations
        get_start = time.perf_counter()
        with self.action_queue_lock:
            self.action_queue_size.append(self.action_queue.qsize())
            # The receiver thread may swap in a (possibly empty) queue between
            # the caller's availability check and this pop — re-check under the
            # lock instead of letting get_nowait() raise and kill the loop.
            if self.action_queue.empty():
                return {}
            # Get action from queue
            timed_action = self.action_queue.get_nowait()
        get_end = time.perf_counter() - get_start

        _performed_action = self.robot.send_action(
            self._action_tensor_to_action_dict(timed_action.get_action())
        )

        with self.latest_action_lock:
            self.latest_action = timed_action.get_timestep()

        if verbose:
            with self.action_queue_lock:
                current_queue_size = self.action_queue.qsize()

            self.logger.debug(
                f"Ts={timed_action.get_timestamp()} | "
                f"Action #{timed_action.get_timestep()} performed | "
                f"Queue size: {current_queue_size}"
            )
            self.logger.debug(
                f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
            )

        return _performed_action

    def _ready_to_send_observation(self):
        """Flags when the client is ready to send an observation"""
        with self.action_queue_lock:
            # NOTE: action_chunk_size is -1 until the first chunk arrives; the
            # queue is empty then, so 0 / -1 <= threshold still returns True.
            return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold

    def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
        """Capture an observation from the robot and stream it to the policy server.

        Args:
            task: Task string forwarded to the policy with the observation.
            verbose: When True, log FPS metrics per observation.

        Returns:
            The raw observation captured, or None if capture/sending failed.
        """
        try:
            # Get serialized observation bytes from the function
            start_time = time.perf_counter()
            raw_observation: RawObservation = self.robot.get_observation()
            raw_observation["task"] = task

            with self.latest_action_lock:
                latest_action = self.latest_action

            observation = TimedObservation(
                timestamp=time.time(),  # need time.time() to compare timestamps across client and server
                observation=raw_observation,
                timestep=max(latest_action, 0),
            )

            obs_capture_time = time.perf_counter() - start_time

            # If there are no actions left in the queue, the observation must go through processing!
            with self.action_queue_lock:
                observation.must_go = self.must_go.is_set() and self.action_queue.empty()
                current_queue_size = self.action_queue.qsize()

            _ = self.send_observation(observation)
            self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")

            if observation.must_go:
                # must-go event will be set again after receiving actions
                self.must_go.clear()

            if verbose:
                # Calculate comprehensive FPS metrics
                fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())

                self.logger.info(
                    f"Obs #{observation.get_timestep()} | "
                    f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
                    f"Target: {fps_metrics['target_fps']:.2f}"
                )

                self.logger.debug(
                    f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
                )

            return raw_observation

        except Exception as e:
            # Best-effort: log and return None so the control loop keeps running.
            self.logger.error(f"Error in observation sender: {e}")

    def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
        """Combined function for executing actions and streaming observations"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        self.logger.info("Control loop thread starting")

        _performed_action = None
        _captured_observation = None
        while self.running:
            control_loop_start = time.perf_counter()

            # Control loop: (1) Performing actions, when available
            if self.actions_available():
                _performed_action = self.control_loop_action(verbose)

            # Control loop: (2) Streaming observations to the remote policy server
            if self._ready_to_send_observation():
                _captured_observation = self.control_loop_observation(task, verbose)

            self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
            # Dynamically adjust sleep time to maintain the desired control frequency
            time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))

        return _captured_observation, _performed_action
@draccus.wrap()
def async_client(cfg: RobotClientConfig):
    """Entry point: run a RobotClient against a remote policy server.

    A daemon thread receives action chunks from the server while the main
    thread drives the robot control loop.
    """
    logging.info(pformat(asdict(cfg)))

    if cfg.robot.type not in SUPPORTED_ROBOTS:
        raise ValueError(f"Robot {cfg.robot.type} not yet supported!")

    client = RobotClient(cfg)

    if client.start():
        client.logger.info("Starting action receiver thread...")
        receiver = threading.Thread(target=client.receive_actions, daemon=True)
        receiver.start()

        try:
            # The main thread runs the control loop
            client.control_loop(task=cfg.task)
        finally:
            client.stop()
            receiver.join()
            if cfg.debug_visualize_queue_size:
                visualize_action_queue_size(client.action_queue_size)
            client.logger.info("Client stopped")


if __name__ == "__main__":
    async_client()  # run the client

View File

@@ -0,0 +1,59 @@
// fmt: off
// flake8: noqa
// !/usr/bin/env python
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package async_inference;

// AsyncInference: from Robot perspective
// Robot send observations to & executes action received from a remote Policy server
service AsyncInference {
  // Robot -> Policy to share observations with a remote inference server
  // Policy -> Robot to share actions predicted for given observations
  rpc SendObservations(stream Observation) returns (Empty);
  // Robot polls the server for the latest predicted action chunk
  // (Actions.data is empty when no chunk is available yet).
  rpc GetActions(Empty) returns (Actions);
  // Robot sends the serialized policy specification the server should load.
  rpc SendPolicyInstructions(PolicySetup) returns (Empty);
  // Handshake used by the client to verify the server is reachable.
  rpc Ready(Empty) returns (Empty);
  // Request the server to stop serving.
  rpc Stop(Empty) returns (Empty);
}

// Position of a message within a multi-message (chunked) transfer.
enum TransferState {
  TRANSFER_UNKNOWN = 0;
  TRANSFER_BEGIN = 1;
  TRANSFER_MIDDLE = 2;
  TRANSFER_END = 3;
}

// Messages
message Observation {
  // sent by Robot, to remote Policy
  TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
  bytes data = 2; // pickled TimedObservation (possibly one chunk of it)
}

message Actions {
  // sent by remote Policy, to Robot
  bytes data = 1; // pickled list of TimedAction; empty when none available
}

message PolicySetup {
  // sent by Robot to remote server, to init Policy
  bytes data = 1; // pickled RemotePolicyConfig
}

message Empty {}

View File

@@ -0,0 +1,45 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Fail fast if the installed protobuf runtime is older than the gencode
# version (5.29.0) that produced this module.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'async_inference.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

# Serialized FileDescriptorProto for async_inference.proto; every message,
# enum and service descriptor below is built from this blob.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')

_globals = globals()
# Populate this module's namespace with the generated message/enum classes.
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  # Byte offsets of each descriptor inside the serialized file above, used by
  # the pure-Python descriptor implementation.
  _globals['_TRANSFERSTATE']._serialized_start=190
  _globals['_TRANSFERSTATE']._serialized_end=286
  _globals['_OBSERVATION']._serialized_start=42
  _globals['_OBSERVATION']._serialized_end=125
  _globals['_ACTIONS']._serialized_start=127
  _globals['_ACTIONS']._serialized_end=150
  _globals['_POLICYSETUP']._serialized_start=152
  _globals['_POLICYSETUP']._serialized_end=179
  _globals['_EMPTY']._serialized_start=181
  _globals['_EMPTY']._serialized_end=188
  _globals['_ASYNCINFERENCE']._serialized_start=289
  _globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)
View File

@@ -0,0 +1,277 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
# Guard: the installed grpcio runtime must be at least as new as the
# grpcio-tools version that generated this module.
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The comparison helper itself is missing on very old grpc releases,
    # which are therefore unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in async_inference_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
class AsyncInferenceStub:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # One callable per RPC defined in async_inference.proto; serializers
        # bridge between the generated message classes and the wire format.
        self.SendObservations = channel.stream_unary(
                '/async_inference.AsyncInference/SendObservations',
                request_serializer=async__inference__pb2.Observation.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.GetActions = channel.unary_unary(
                '/async_inference.AsyncInference/GetActions',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Actions.FromString,
                _registered_method=True)
        self.SendPolicyInstructions = channel.unary_unary(
                '/async_inference.AsyncInference/SendPolicyInstructions',
                request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.Ready = channel.unary_unary(
                '/async_inference.AsyncInference/Ready',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.Stop = channel.unary_unary(
                '/async_inference.AsyncInference/Stop',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
class AsyncInferenceServicer:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    # Abstract base for server implementations: every handler defaults to
    # UNIMPLEMENTED and is meant to be overridden (see PolicyServer).

    def SendObservations(self, request_iterator, context):
        """Robot -> Policy to share observations with a remote inference server
        Policy -> Robot to share actions predicted for given observations
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetActions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendPolicyInstructions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Ready(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Stop(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
    """Register `servicer`'s handlers for the AsyncInference service on `server`."""
    rpc_method_handlers = {
            'SendObservations': grpc.stream_unary_rpc_method_handler(
                    servicer.SendObservations,
                    request_deserializer=async__inference__pb2.Observation.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'GetActions': grpc.unary_unary_rpc_method_handler(
                    servicer.GetActions,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Actions.SerializeToString,
            ),
            'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPolicyInstructions,
                    request_deserializer=async__inference__pb2.PolicySetup.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Ready': grpc.unary_unary_rpc_method_handler(
                    servicer.Ready,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Stop': grpc.unary_unary_rpc_method_handler(
                    servicer.Stop,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'async_inference.AsyncInference', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
    """AsyncInference: from Robot perspective
    Robot send observations to & executes action received from a remote Policy server
    """

    # One-shot convenience wrappers (grpc.experimental API): each static method
    # connects to `target`, performs the RPC and returns the response, without
    # a long-lived stub.

    @staticmethod
    def SendObservations(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.stream_unary(
            request_iterator,
            target,
            '/async_inference.AsyncInference/SendObservations',
            async__inference__pb2.Observation.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetActions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/GetActions',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Actions.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendPolicyInstructions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/SendPolicyInstructions',
            async__inference__pb2.PolicySetup.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Ready(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Ready',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Stop(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Stop',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)