Add Async Inference (#1196)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Francesco Capuano
2025-07-10 10:39:11 +02:00
committed by GitHub
parent ce2b9724bf
commit 30c161006d
15 changed files with 3266 additions and 1 deletion


@@ -17,6 +17,8 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: async
title: Use Async Inference
title: "Tutorials"
- sections:
- local: smolvla

docs/source/async.mdx Normal file

@@ -0,0 +1,272 @@
# Asynchronous Inference
With [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA.
**Try async inference with all the policies supported by LeRobot!**
**What you'll learn:**
1. Why asynchronous inference matters and how it compares to more traditional sequential inference.
2. How to spin up a `PolicyServer` and connect a `RobotClient` on the same machine, or even over the network.
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.
If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!
In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions, eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.
---
## Getting started with async inference
You can read more about asynchronous inference in our [blogpost](NOTE:blogpost). Here, we provide a getting-started guide to help you set up and run asynchronous inference on your own setup.
First, install `lerobot` with the `async` extra, to pull in the additional dependencies required to run async inference.
```shell
pip install -e ".[async]"
```
Then, spin up a policy server (in one terminal, or on a separate machine), specifying the host address and port the client will connect to.
You can spin up a policy server by running:
```shell
python src/lerobot/scripts/server/policy_server.py \
--host=127.0.0.1 \
--port=8080
```
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information about which policy to run, and with which parameters, is specified during the first handshake with the client. Spin up a client with:
```shell
python src/lerobot/scripts/server/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
```
In summary, you need to specify instructions for:
- `SERVER`: the address and port of the policy server
- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot
- `POLICY`: the type of policy to run, and the model name/path (resolved on the server) of the checkpoint to run. You also need to specify which device the server should use, and how many actions to output at once (capped at the policy's max actions value).
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.
Importantly,
- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup.
- `aggregate_fn_name` is the function used to aggregate actions on overlapping portions. You can either pick one from the registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC))
- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters.
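For intuition, an aggregation function simply blends the overlapping portion of the old and new action chunks. Below is a minimal sketch on plain floats mirroring the `weighted_average` registry entry (0.3 × old + 0.7 × new); the actual registry functions operate on `torch.Tensor`s:

```python
# Sketch of chunk aggregation on overlapping action portions.
# Mirrors the registry's "weighted_average" (0.3 * old + 0.7 * new);
# shown on floats for illustration, the real functions act on tensors.
def weighted_average(old: float, new: float) -> float:
    return 0.3 * old + 0.7 * new

# Overlapping tail of the executing chunk vs. head of the fresh chunk
old_tail = [1.0, 1.0, 1.0]
new_head = [2.0, 2.0, 2.0]
aggregated = [weighted_average(o, n) for o, n in zip(old_tail, new_head)]
print(aggregated)  # [1.7, 1.7, 1.7]
```

New chunks are weighted more heavily than old ones, so the aggregated plan tracks the most recent observation while staying smooth across chunk boundaries.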
Done! You should see your robot moving around by now 😉
---
## Async vs. synchronous inference
Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in *idle frames*: frames where the robot sits idle, awaiting the policy's output, a new action chunk.
In turn, inference suffers from evident real-time lags, during which the robot simply stops acting due to the lack of available actions.
With robotics models increasing in size, this problem risks becoming only more severe.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png" width="80%"></img>
</p>
<p align="center"><i>Synchronous inference</i> makes the robot idle while the policy is computing the next chunk of actions.</p>
To overcome this, we designed async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
Crucially, with async inference, the next action chunk is computed *before* the current one is exhausted, resulting in no idleness.
Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png" width="80%"></img>
</p>
<p align="center"><i>Asynchronous inference</i> results in no idleness because the next chunk is computed before the current chunk is exhausted.</p>
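Schematically, the two regimes differ in where inference time goes. The sketch below is not LeRobot's actual implementation (`predict_chunk` and `execute` are stand-ins): it contrasts a sequential loop, which idles during inference, with an asynchronous one that refills an action queue from a background thread before the queue runs dry:

```python
import queue
import threading

# Stand-in functions for illustration only (not LeRobot APIs)
def predict_chunk(observation):
    """Slow step: would run on the policy server."""
    return [observation + i + 1 for i in range(4)]

def execute(action):
    """Fast step: one robot control cycle."""
    pass

def sync_loop(steps=2):
    """Synchronous: the robot idles while predict_chunk runs."""
    executed, obs = [], 0
    for _ in range(steps):
        chunk = predict_chunk(obs)  # robot is idle during inference
        for action in chunk:
            execute(action)
            executed.append(action)
        obs = executed[-1]
    return executed

def async_loop(steps=8, threshold=0.5, chunk_len=4):
    """Asynchronous: the queue is refilled in the background before it empties."""
    actions: queue.Queue = queue.Queue()
    executed = []

    def refill(obs):
        for a in predict_chunk(obs):
            actions.put(a)

    refill(0)  # prime the queue with a first chunk
    for _ in range(steps):
        action = actions.get()  # no idling while the queue is non-empty
        execute(action)
        executed.append(action)
        # Request the next chunk *before* the current one is exhausted
        if actions.qsize() <= threshold * chunk_len:
            threading.Thread(target=refill, args=(action,), daemon=True).start()
    return executed
```

The `threshold * chunk_len` check is the same idea as `chunk_size_threshold * actions_per_chunk` in the client: inference for the next chunk overlaps with execution of the current one.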
---
## Start the Policy Server
Policy servers are wrappers around a `PreTrainedPolicy`, interfacing it with observations coming from a robot client.
Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server.
As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.
<hfoptions id="start_policy_server">
<hfoption id="Command">
```bash
python -m lerobot.scripts.server.policy_server \
--host="localhost" \
--port=8080
```
</hfoption>
<hfoption id="API example">
```python
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import serve

config = PolicyServerConfig(
    host="localhost",
    port=8080,
)
serve(config)
```
</hfoption>
</hfoptions>
This listens on `localhost:8080` for an incoming connection from the associated `RobotClient`, which will communicate which policy to run during the first client-server handshake.
---
## Launch the Robot Client
`RobotClient` is a wrapper around a `Robot` instance that connects to the (possibly remote) `PolicyServer`.
The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained by running inference on the server (which we assume has more computational resources than the robot controller).
<hfoptions id="start_robot_client">
<hfoption id="Command">
```bash
python src/lerobot/scripts/server/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
--robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file
--robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy
--task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act`
--policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc)
--pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base)
--policy_device=mps \ # POLICY: the device to run the policy on, on the server
--actions_per_chunk=50 \ # POLICY: the number of actions to output at once
--chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server
--aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions
--debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime
```
</hfoption>
<hfoption id="API example">
```python
import threading

from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.scripts.server.helpers import visualize_action_queue_size

# 1. Create the robot configuration
# Check out the cameras available in your setup by running `python lerobot/find_cameras.py`.
# These cameras must match the ones expected by the policy:
# check the config.json on the Hub for the policy you are using.
camera_cfg = {
    "top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(
    port="/dev/tty.usbmodem585A0076841",
    id="follower_so100",
    cameras=camera_cfg,
)

# 2. Create the client configuration
client_cfg = RobotClientConfig(
    robot=robot_cfg,
    server_address="localhost:8080",
    policy_device="mps",
    policy_type="smolvla",
    pretrained_name_or_path="fracapuano/smolvla_async",
    chunk_size_threshold=0.5,
    actions_per_chunk=50,  # make sure this is less than the max actions of the policy
)

# 3. Create and start the client
client = RobotClient(client_cfg)

# 4. Specify the task
task = "Don't do anything, stay still"

if client.start():
    # Start the action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()
    try:
        # Run the control loop
        client.control_loop(task)
    except KeyboardInterrupt:
        client.stop()
        action_receiver_thread.join()
        # (Optionally) plot the action queue size
        visualize_action_queue_size(client.action_queue_size)
```
</hfoption>
</hfoptions>
The following two parameters are key in every setup:
<table>
<thead>
<tr>
<th>Hyperparameter</th>
<th>Default</th>
<th>What it does</th>
</tr>
</thead>
<tbody>
<tr>
<td><code>actions_per_chunk</code></td>
<td>50</td>
<td>How many actions the policy outputs at once. Typical values: 10-50.</td>
</tr>
<tr>
<td><code>chunk_size_threshold</code></td>
<td>0.5</td>
<td>When the queue length falls to ≤ <code>chunk_size_threshold</code> × <code>actions_per_chunk</code> (with 0.5, half full), the client sends a fresh observation. Value in [0, 1].</td>
</tr>
</tbody>
</table>
<Tip>
Different values of `actions_per_chunk` and `chunk_size_threshold` result in different behaviours.
</Tip>
On the one hand, increasing `actions_per_chunk` reduces the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed.
However, larger values of `actions_per_chunk` may also result in less precise actions, due to the compounding errors of predicting actions over longer timespans.
On the other hand, increasing `chunk_size_threshold` sends observations to the `PolicyServer` for inference more often, resulting in a larger number of updated action chunks that overlap over significant portions. This yields high adaptability: in the limit, one action chunk is predicted for each observation, and each chunk is only marginally consumed before a new one is produced.
This also puts more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out once the current chunk is exhausted.
We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup.
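As a back-of-the-envelope check, these parameters imply a concrete inference budget. With `fps=30`, `actions_per_chunk=50` and `chunk_size_threshold=0.5`, a new observation is sent once 25 actions remain in the queue, so the server has roughly 25/30 of a second to return the next chunk before the client runs dry:

```python
# Back-of-the-envelope budget implied by the client parameters
# (values here match the defaults discussed above).
fps = 30                     # control frequency (DEFAULT_FPS)
actions_per_chunk = 50       # actions returned per inference call
chunk_size_threshold = 0.5   # queue fraction triggering a new observation

# The client sends a new observation once this many actions remain queued:
actions_left_at_send = chunk_size_threshold * actions_per_chunk
# Time the server has to return the next chunk before the queue empties:
inference_budget_s = actions_left_at_send / fps
print(f"{actions_left_at_send:.0f} actions left -> {inference_budget_s:.2f} s budget")
```

If your measured server latency (including network round-trip) exceeds this budget, lower the fps, raise `actions_per_chunk`, or raise `chunk_size_threshold`.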
### Tuning async inference for your setup
1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. Identify the best computational resource for your use case, keeping in mind that smaller policies require fewer resources. The combination of policy and device (CPU-only, MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect.
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
3. **Adjust `chunk_size_threshold`**.
- Values closer to `0.0` result in almost synchronous behavior; values closer to `1.0` send an observation at every step (more bandwidth; relies on a good world model).
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` with `--debug_visualize_queue_size=True`. This plots the evolution of the action queue size at runtime, which you can use to find the value of `chunk_size_threshold` that works best for your setup.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png" width="80%"></img>
</p>
<p align="center"><i>The action queue size is plotted at runtime when the `--debug_visualize_queue_size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).</i></p>
---
## Conclusion
Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.
**Key Takeaways:**
- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).


@@ -46,7 +46,7 @@ classifiers = [
]
dependencies = [
"cmake>=3.29.0.1",
"datasets>=2.19.0",
"datasets>=2.19.0,<=3.6.0",
"deepdiff>=7.0.1",
"diffusers>=0.27.2",
"draccus==0.10.0",
@@ -105,6 +105,7 @@ hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio
umi = ["imagecodecs>=2024.1.1"]
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
async = ["grpcio==1.71.0", "matplotlib>=3.10.3"]
[tool.poetry]
requires-poetry = ">=2.1"


@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Callable
import torch
from lerobot.robots.config import RobotConfig
from lerobot.scripts.server.constants import (
DEFAULT_FPS,
DEFAULT_INFERENCE_LATENCY,
DEFAULT_OBS_QUEUE_TIMEOUT,
)
# Aggregate function registry for CLI usage
AGGREGATE_FUNCTIONS = {
"weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
"latest_only": lambda old, new: new,
"average": lambda old, new: 0.5 * old + 0.5 * new,
"conservative": lambda old, new: 0.7 * old + 0.3 * new,
}
def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Get aggregate function by name from registry."""
if name not in AGGREGATE_FUNCTIONS:
available = list(AGGREGATE_FUNCTIONS.keys())
raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
return AGGREGATE_FUNCTIONS[name]
@dataclass
class PolicyServerConfig:
"""Configuration for PolicyServer.
This class defines all configurable parameters for the PolicyServer,
including networking settings and action chunking specifications.
"""
# Networking configuration
host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})
# Timing configuration
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
inference_latency: float = field(
default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
)
obs_queue_timeout: float = field(
default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
)
def __post_init__(self):
"""Validate configuration after initialization."""
if self.port < 1 or self.port > 65535:
raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
if self.environment_dt <= 0:
raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")
if self.inference_latency < 0:
raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")
if self.obs_queue_timeout < 0:
raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")
@classmethod
def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
"""Create a PolicyServerConfig from a dictionary."""
return cls(**config_dict)
@property
def environment_dt(self) -> float:
"""Environment time step, in seconds"""
return 1 / self.fps
def to_dict(self) -> dict:
"""Convert the configuration to a dictionary."""
return {
"host": self.host,
"port": self.port,
"fps": self.fps,
"environment_dt": self.environment_dt,
"inference_latency": self.inference_latency,
}
@dataclass
class RobotClientConfig:
"""Configuration for RobotClient.
This class defines all configurable parameters for the RobotClient,
including network connection, policy settings, and control behavior.
"""
# Policy configuration
policy_type: str = field(metadata={"help": "Type of policy to use"})
pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})
# Robot configuration (for CLI usage - robot instance will be created from this)
robot: RobotConfig = field(metadata={"help": "Robot configuration"})
# Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions
# would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})
# Task instruction for the robot to execute (e.g., 'fold my tshirt')
task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})
# Network configuration
server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})
# Device configuration
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
# Control behavior configuration
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
# Aggregate function configuration (CLI-compatible)
aggregate_fn_name: str = field(
default="weighted_average",
metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
)
# Debug configuration
debug_visualize_queue_size: bool = field(
default=False, metadata={"help": "Visualize the action queue size"}
)
# Verification configuration
verify_robot_cameras: bool = field(
default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
)
@property
def environment_dt(self) -> float:
"""Environment time step, in seconds"""
return 1 / self.fps
def __post_init__(self):
"""Validate configuration after initialization."""
if not self.server_address:
raise ValueError("server_address cannot be empty")
if not self.policy_type:
raise ValueError("policy_type cannot be empty")
if not self.pretrained_name_or_path:
raise ValueError("pretrained_name_or_path cannot be empty")
if not self.policy_device:
raise ValueError("policy_device cannot be empty")
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
if self.fps <= 0:
raise ValueError(f"fps must be positive, got {self.fps}")
if self.actions_per_chunk <= 0:
raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")
self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)
@classmethod
def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
"""Create a RobotClientConfig from a dictionary."""
return cls(**config_dict)
def to_dict(self) -> dict:
"""Convert the configuration to a dictionary."""
return {
"server_address": self.server_address,
"policy_type": self.policy_type,
"pretrained_name_or_path": self.pretrained_name_or_path,
"policy_device": self.policy_device,
"chunk_size_threshold": self.chunk_size_threshold,
"fps": self.fps,
"actions_per_chunk": self.actions_per_chunk,
"task": self.task,
"debug_visualize_queue_size": self.debug_visualize_queue_size,
"aggregate_fn_name": self.aggregate_fn_name,
}


@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client side: The environment evolves with a time resolution equal to 1/fps"""
DEFAULT_FPS = 30
"""Server side: Running inference at most every 1/fps seconds"""
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
"""Server side: Timeout for observation queue in seconds"""
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]


@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import logging.handlers
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Event
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.transport import async_inference_pb2
from lerobot.transport.utils import bytes_buffer_size
from lerobot.utils.utils import init_logging
Action = torch.Tensor
ActionChunk = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
# observation as those recorded in LeRobot dataset (keys are different)
LeRobotObservation = dict[str, torch.Tensor]
# observation, ready for policy inference (image keys resized)
Observation = dict[str, torch.Tensor]
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
    """Plot the evolution of the action queue size over time."""
    import matplotlib.pyplot as plt

    if not action_queue_size:
        return  # nothing to plot
    fig, ax = plt.subplots()
    ax.set_title("Action Queue Size Over Time")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Action Queue Size")
    ax.set_ylim(0, max(action_queue_size) * 1.1)
    ax.grid(True, alpha=0.3)
    ax.plot(range(len(action_queue_size)), action_queue_size)
    plt.show()
def validate_robot_cameras_for_policy(
lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
) -> None:
image_keys = list(filter(is_image_key, lerobot_observation_features))
assert set(image_keys) == set(policy_image_features.keys()), (
f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
)
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def resize_robot_observation_image(image: torch.Tensor, resize_dims: tuple[int, int, int]) -> torch.Tensor:
    assert image.ndim == 3, f"Image must be (H, W, C)! Received {image.shape}"
    # (H, W, C) -> (C, H, W) for resizing from robot observation resolution to policy image resolution
image = image.permute(2, 0, 1)
dims = (resize_dims[1], resize_dims[2])
# Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
image_batched = image.unsqueeze(0)
# Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)
return resized.squeeze(0)
def raw_observation_to_observation(
raw_observation: RawObservation,
lerobot_features: dict[str, dict],
policy_image_features: dict[str, PolicyFeature],
device: str,
) -> Observation:
    observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
for k, v in observation.items():
if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations
if "image" in k:
# Policy expects images in shape (B, C, H, W)
observation[k] = prepare_image(v).unsqueeze(0).to(device)
else:
observation[k] = v.to(device)
else:
observation[k] = v
return observation
def prepare_image(image: torch.Tensor) -> torch.Tensor:
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
image = image.type(torch.float32) / 255
image = image.contiguous()
return image
def extract_state_from_raw_observation(
lerobot_obs: RawObservation,
) -> torch.Tensor:
"""Extract the state from a raw observation."""
state = torch.tensor(lerobot_obs[OBS_STATE])
if state.ndim == 1:
state = state.unsqueeze(0)
return state
def extract_images_from_raw_observation(
    lerobot_obs: RawObservation,
    camera_key: str,
) -> torch.Tensor:
    """Extract the image tensor for a given camera key from a raw observation."""
    return torch.tensor(lerobot_obs[camera_key])
def make_lerobot_observation(
robot_obs: RawObservation,
lerobot_features: dict[str, dict],
) -> LeRobotObservation:
"""Make a lerobot observation from a raw observation."""
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
def prepare_raw_observation(
robot_obs: RawObservation,
lerobot_features: dict[str, dict],
policy_image_features: dict[str, PolicyFeature],
) -> Observation:
"""Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
policy_image_features)."""
# 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
# -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)
# 2. Greps all observation.images.<> keys
image_keys = list(filter(is_image_key, lerobot_obs))
# state's shape is expected as (B, state_dim)
state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}
    # Turn the image features into (C, H, W) with H, W matching the policy image features.
    # This reduces the resolution of the images
    image_dict = {
        key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
        for key in image_keys
    }
if "task" in robot_obs:
state_dict["task"] = robot_obs["task"]
return {**state_dict, **image_dict}
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
"""
Get a logger using the standardized logging setup from utils.py.
Args:
name: Logger name (e.g., 'policy_server', 'robot_client')
log_to_file: Whether to also log to a file
Returns:
Configured logger instance
"""
# Create logs directory if logging to file
if log_to_file:
os.makedirs("logs", exist_ok=True)
log_file = Path(f"logs/{name}_{int(time.time())}.log")
else:
log_file = None
# Initialize the standardized logging
init_logging(log_file=log_file, display_pid=False)
# Return a named logger
return logging.getLogger(name)
@dataclass
class TimedData:
"""Base class for timed data objects, carrying timestamp and timestep information.
Args:
timestamp: Unix timestamp taken at the data's creation.
timestep: The timestep of the data.
"""
timestamp: float
timestep: int
def get_timestamp(self):
return self.timestamp
def get_timestep(self):
return self.timestep
@dataclass
class TimedAction(TimedData):
action: Action
def get_action(self):
return self.action
@dataclass
class TimedObservation(TimedData):
observation: RawObservation
must_go: bool = False
def get_observation(self):
return self.observation
@dataclass
class FPSTracker:
"""Utility class to track FPS metrics over time."""
target_fps: float
first_timestamp: float = None
total_obs_count: int = 0
def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
"""Calculate average FPS vs target"""
self.total_obs_count += 1
# Initialize first observation time
if self.first_timestamp is None:
self.first_timestamp = current_timestamp
# Calculate overall average FPS (since start)
total_duration = current_timestamp - self.first_timestamp
avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
return {"avg_fps": avg_fps, "target_fps": self.target_fps}
def reset(self):
"""Reset the FPS tracker state"""
self.first_timestamp = None
self.total_obs_count = 0
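The average-FPS formula above divides the number of completed intervals (observations minus one) by the time elapsed since the first timestamp, so the very first observation reports 0.0. A minimal standalone sketch of that arithmetic (the class is re-declared here so the snippet runs on its own, with made-up timestamps):

```python
from dataclasses import dataclass


@dataclass
class FPSTracker:
    target_fps: float
    first_timestamp: float = None
    total_obs_count: int = 0

    def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
        self.total_obs_count += 1
        if self.first_timestamp is None:
            self.first_timestamp = current_timestamp
        total_duration = current_timestamp - self.first_timestamp
        # count observations but divide by (count - 1) elapsed intervals
        avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0
        return {"avg_fps": avg_fps, "target_fps": self.target_fps}


tracker = FPSTracker(target_fps=30.0)
tracker.calculate_fps_metrics(100.0)            # first observation: avg_fps == 0.0
metrics = tracker.calculate_fps_metrics(100.1)  # one 100 ms interval -> avg_fps ~= 10
```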
@dataclass
class RemotePolicyConfig:
policy_type: str
pretrained_name_or_path: str
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
"""Check if two observation states are similar, under a tolerance threshold"""
return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
def observations_similar(
obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
"""Check if two observations are similar, under a tolerance threshold. Measures distance between
observations as the difference in joint-space between the two observations.
NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
to surpass this joint-space similarity check.
"""
obs1_state = extract_state_from_raw_observation(
make_lerobot_observation(obs1.get_observation(), lerobot_features)
)
obs2_state = extract_state_from_raw_observation(
make_lerobot_observation(obs2.get_observation(), lerobot_features)
)
return _compare_observation_states(obs1_state, obs2_state, atol=atol)
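The similarity check above boils down to comparing the L2 norm of the joint-space state difference against a tolerance. A standalone sketch of that comparison, using plain lists in place of tensors:

```python
import math


def states_similar(s1, s2, atol: float = 1.0) -> bool:
    # Mirrors _compare_observation_states: L2 distance between joint-space states under a tolerance
    distance = math.sqrt(sum((a - b) ** 2 for a, b in zip(s1, s2)))
    return distance < atol


resting = [0.0, 0.0, 0.0]
nearby = [0.5, 0.5, 0.0]    # distance ~0.707 -> similar under the default atol=1
far_away = [2.0, 0.0, 0.0]  # distance 2.0 -> not similar
```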
def send_bytes_in_chunks(
buffer: bytes,
message_class: Any,
log_prefix: str = "",
silent: bool = True,
chunk_size: int = 3 * 1024 * 1024,
):
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
# chunk size as I am using it to send image observations.
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging_method = logging.info if not silent else logging.debug
logging_method(f"{log_prefix} Buffer size: {size_in_bytes / 1024 / 1024:.2f} MB")
while sent_bytes < size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + chunk_size >= size_in_bytes:
transfer_state = async_inference_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
): # type: ignore
# NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
# don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the receiving logic
# is opposite to the HIL-SERL design (here the event signals "keep running" rather than "shut down")
bytes_buffer = io.BytesIO()
step = 0
logger.info(f"{log_prefix} Starting receiver")
for item in iterator:
logger.debug(f"{log_prefix} Received item")
if not continue_receiving.is_set():
logger.info(f"{log_prefix} Shutting down receiver")
return
if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step 0")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
bytes_buffer.write(item.data)
step += 1
logger.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
complete_bytes = bytes_buffer.getvalue()
bytes_buffer.seek(0)
bytes_buffer.truncate(0)
logger.debug(f"{log_prefix} Queue updated")
return complete_bytes
else:
logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
raise ValueError(f"Received unknown transfer state {item.transfer_state}")
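Taken together, the two helpers above implement a small chunked-transfer protocol over gRPC streaming: the first chunk is tagged `TRANSFER_BEGIN`, intermediate chunks `TRANSFER_MIDDLE`, and the final one `TRANSFER_END`. Note the end condition is checked first, so a message that fits in a single chunk is tagged `TRANSFER_END` directly. A gRPC-free sketch of that state machine, with plain strings standing in for the protobuf `TransferState` enum:

```python
import io

BEGIN, MIDDLE, END = "BEGIN", "MIDDLE", "END"


def iter_chunks(payload: bytes, chunk_size: int):
    """Yield (transfer_state, data) pairs, mirroring send_bytes_in_chunks."""
    buffer = io.BytesIO(payload)
    size, sent = len(payload), 0
    while sent < size:
        if sent + chunk_size >= size:
            state = END      # checked first: single-chunk messages are tagged END
        elif sent == 0:
            state = BEGIN
        else:
            state = MIDDLE
        data = buffer.read(min(chunk_size, size - sent))
        yield state, data
        sent += len(data)


def reassemble(chunks) -> bytes:
    """Accumulate chunk data until END, mirroring receive_bytes_in_chunks."""
    out = io.BytesIO()
    for state, data in chunks:
        out.write(data)
        if state == END:
            return out.getvalue()
    raise ValueError("stream ended without an END chunk")
```

With a 10-byte payload and 3-byte chunks, the emitted states are BEGIN, MIDDLE, MIDDLE, END, and reassembly recovers the payload byte-for-byte.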


@@ -0,0 +1,403 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example:
```shell
python src/lerobot/scripts/server/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
--fps=30 \
--inference_latency=0.033 \
--obs_queue_timeout=1
```
"""
import logging
import pickle # nosec
import threading
import time
from concurrent import futures
from dataclasses import asdict
from pprint import pformat
from queue import Empty, Queue
import draccus
import grpc
import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
observations_similar,
raw_observation_to_observation,
receive_bytes_in_chunks,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
prefix = "policy_server"
logger = get_logger(prefix)
def __init__(self, config: PolicyServerConfig):
self.config = config
self._running_event = threading.Event()
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=config.fps)
self.observation_queue = Queue(maxsize=1)
self._predicted_timesteps_lock = threading.Lock()
self._predicted_timesteps = set()
self.last_processed_obs = None
# Attributes will be set by SendPolicyInstructions
self.device = None
self.policy_type = None
self.lerobot_features = None
self.actions_per_chunk = None
self.policy = None
@property
def running(self):
return self._running_event.is_set()
@property
def policy_image_features(self):
return self.policy.config.image_features
def _reset_server(self) -> None:
"""Flushes server state when new client connects."""
# only running inference on the latest observation received by the server
self._running_event.clear()
self.observation_queue = Queue(maxsize=1)
with self._predicted_timesteps_lock:
self._predicted_timesteps = set()
def Ready(self, request, context): # noqa: N802
client_id = context.peer()
self.logger.info(f"Client {client_id} connected and ready")
self._reset_server()
self._running_event.set()
return async_inference_pb2.Empty()
def SendPolicyInstructions(self, request, context): # noqa: N802
"""Receive policy instructions from the robot client"""
if not self.running:
self.logger.warning("Server is not running. Ignoring policy instructions.")
return async_inference_pb2.Empty()
client_id = context.peer()
policy_specs = pickle.loads(request.data) # nosec
if not isinstance(policy_specs, RemotePolicyConfig):
raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")
if policy_specs.policy_type not in SUPPORTED_POLICIES:
raise ValueError(
f"Policy type {policy_specs.policy_type} not supported. "
f"Supported policies: {SUPPORTED_POLICIES}"
)
self.logger.info(
f"Receiving policy instructions from {client_id} | "
f"Policy type: {policy_specs.policy_type} | "
f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
f"Actions per chunk: {policy_specs.actions_per_chunk} | "
f"Device: {policy_specs.device}"
)
self.device = policy_specs.device
self.policy_type = policy_specs.policy_type # act, pi0, etc.
self.lerobot_features = policy_specs.lerobot_features
self.actions_per_chunk = policy_specs.actions_per_chunk
policy_class = get_policy_class(self.policy_type)
start = time.perf_counter()
self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
self.policy.to(self.device)
end = time.perf_counter()
self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
return async_inference_pb2.Empty()
def SendObservations(self, request_iterator, context): # noqa: N802
"""Receive observations from the robot client"""
client_id = context.peer()
self.logger.debug(f"Receiving observations from {client_id}")
receive_time = time.time() # comparing timestamps so need time.time()
start_deserialize = time.perf_counter()
received_bytes = receive_bytes_in_chunks(
request_iterator, self._running_event, self.logger
) # blocking call while looping over request_iterator
timed_observation = pickle.loads(received_bytes) # nosec
deserialize_time = time.perf_counter() - start_deserialize
self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")
obs_timestep = timed_observation.get_timestep()
obs_timestamp = timed_observation.get_timestamp()
# Calculate FPS metrics
fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)
self.logger.info(
f"Received observation #{obs_timestep} | "
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client
f"Target: {fps_metrics['target_fps']:.2f} | "
f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
)
self.logger.debug(
f"Server timestamp: {receive_time:.6f} | "
f"Client timestamp: {obs_timestamp:.6f} | "
f"Deserialization time: {deserialize_time:.6f}s"
)
if not self._enqueue_observation(
timed_observation # wrapping a RawObservation
):
self.logger.info(f"Observation #{obs_timestep} has been filtered out")
return async_inference_pb2.Empty()
def GetActions(self, request, context): # noqa: N802
"""Returns actions to the robot client. Actions are sent as a single
chunk, containing multiple actions."""
client_id = context.peer()
self.logger.debug(f"Client {client_id} connected for action streaming")
# Generate action based on the most recent observation and its timestep
try:
getactions_starts = time.perf_counter()
obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
self.logger.info(
f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
)
with self._predicted_timesteps_lock:
self._predicted_timesteps.add(obs.get_timestep())
start_time = time.perf_counter()
action_chunk = self._predict_action_chunk(obs)
inference_time = time.perf_counter() - start_time
start_time = time.perf_counter()
actions_bytes = pickle.dumps(action_chunk) # nosec
serialize_time = time.perf_counter() - start_time
# Create and return the action chunk
actions = async_inference_pb2.Actions(data=actions_bytes)
self.logger.info(
f"Action chunk #{obs.get_timestep()} generated | "
f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
)
self.logger.debug(
f"Action chunk #{obs.get_timestep()} generated | "
f"Inference time: {inference_time:.2f}s | "
f"Serialize time: {serialize_time:.2f}s | "
f"Total time: {inference_time + serialize_time:.2f}s"
)
time.sleep(
max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
) # sleep controls inference latency
return actions
except Empty: # no observation added to queue in obs_queue_timeout
return async_inference_pb2.Empty()
except Exception as e:
self.logger.error(f"Error in GetActions: {e}")
return async_inference_pb2.Empty()
def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
"""Check if the observation is valid to be processed by the policy"""
with self._predicted_timesteps_lock:
predicted_timesteps = self._predicted_timesteps
if obs.get_timestep() in predicted_timesteps:
self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
return False
elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
self.logger.debug(
f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
)
return False
else:
return True
def _enqueue_observation(self, obs: TimedObservation) -> bool:
"""Enqueue an observation if it must go through processing, otherwise skip it.
Observations not in queue are never run through the policy network"""
if (
obs.must_go
or self.last_processed_obs is None
or self._obs_sanity_checks(obs, self.last_processed_obs)
):
last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
self.logger.debug(
f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
)
# If queue is full, get the old observation to make room
if self.observation_queue.full():
# pops from queue
_ = self.observation_queue.get_nowait()
self.logger.debug("Observation queue was full, removed oldest observation")
# Now put the new observation (never blocks as queue is non-full here)
self.observation_queue.put(obs)
return True
return False
def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
"""Turn a chunk of actions into a list of TimedAction instances,
with the first action corresponding to t_0 and the rest corresponding to
t_0 + i*environment_dt for i in range(len(action_chunk))
"""
return [
TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
for i, action in enumerate(action_chunk)
]
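The timing rule above schedules the i-th action of a chunk at `t_0 + i * environment_dt`, with timestep `i_0 + i`. A standalone sketch of that scheduling, using plain tuples in place of `TimedAction`:

```python
def time_action_chunk(t0: float, actions: list, i0: int, environment_dt: float):
    # i-th action of the chunk is scheduled at t0 + i * dt, with timestep i0 + i
    return [(t0 + i * environment_dt, i0 + i, action) for i, action in enumerate(actions)]


timed = time_action_chunk(t0=100.0, actions=["a0", "a1", "a2"], i0=5, environment_dt=0.1)
# timesteps 5, 6, 7 at timestamps 100.0, 100.1, 100.2 (up to float rounding)
```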
def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
"""
Prepare observation, ready for policy inference.
E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
client and then convert them to float32 [0,1] images here, before running inference.
"""
# RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
observation: Observation = raw_observation_to_observation(
observation_t.get_observation(),
self.lerobot_features,
self.policy_image_features,
self.device,
)
# processed Observation - right keys, right dtype, right image shape
return observation
def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Get an action chunk from the policy. The chunk contains at most `actions_per_chunk` actions."""
chunk = self.policy.predict_action_chunk(observation)
if chunk.ndim != 3:
chunk = chunk.unsqueeze(0)  # adding batch dimension, now shape is (B, chunk_size, action_dim)
return chunk[:, : self.actions_per_chunk, :]
def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
"""Predict an action chunk based on an observation"""
inference_starts = time.perf_counter()
"""1. Prepare observation"""
start_time = time.perf_counter()
observation = self._prepare_observation(observation_t)
preprocessing_time = time.perf_counter() - start_time
self.last_processed_obs: TimedObservation = observation_t
"""2. Get action chunk"""
start_time = time.perf_counter()
action_tensor = self._get_action_chunk(observation)
inference_time = time.perf_counter() - start_time
"""3. Post-inference processing"""
start_time = time.perf_counter()
# Move to CPU before serializing
action_tensor = action_tensor.cpu().squeeze(0)
action_chunk = self._time_action_chunk(
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
)
postprocessing_time = time.perf_counter() - start_time
inference_stops = time.perf_counter()
self.logger.info(
f"Observation {observation_t.get_timestep()} |"
f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
)
# full-process latency breakdown for debugging purposes
self.logger.debug(
f"Observation {observation_t.get_timestep()} | "
f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
f"Inference time: {1000 * inference_time:.2f}ms | "
f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
f"Total time: {1000 * (inference_stops - inference_starts):.2f}ms"
)
return action_chunk
def stop(self):
"""Stop the server"""
self._reset_server()
self.logger.info("Server stopping...")
@draccus.wrap()
def serve(cfg: PolicyServerConfig):
"""Start the PolicyServer with the given configuration.
Args:
cfg: PolicyServerConfig instance.
"""
logging.info(pformat(asdict(cfg)))
# Create the server instance first
policy_server = PolicyServer(cfg)
# Setup and start gRPC server
server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
server.add_insecure_port(f"{cfg.host}:{cfg.port}")
policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
server.start()
server.wait_for_termination()
policy_server.logger.info("Server terminated")
if __name__ == "__main__":
serve()


@@ -0,0 +1,509 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example command:
```shell
python src/lerobot/scripts/server/robot_client.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
--robot.id=black \
--task="dummy" \
--server_address=127.0.0.1:8080 \
--policy_type=act \
--pretrained_name_or_path=user/model \
--policy_device=mps \
--actions_per_chunk=50 \
--chunk_size_threshold=0.5 \
--aggregate_fn_name=weighted_average \
--debug_visualize_queue_size=True
```
"""
import logging
import pickle # nosec
import threading
import time
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any, Callable, Optional
import draccus
import grpc
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs.policies import PreTrainedConfig
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
from lerobot.scripts.server.helpers import (
Action,
FPSTracker,
Observation,
RawObservation,
RemotePolicyConfig,
TimedAction,
TimedObservation,
get_logger,
map_robot_keys_to_lerobot_features,
send_bytes_in_chunks,
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
class RobotClient:
prefix = "robot_client"
logger = get_logger(prefix)
def __init__(self, config: RobotClientConfig):
"""Initialize RobotClient with unified configuration.
Args:
config: RobotClientConfig containing all configuration parameters
"""
# Store configuration
self.config = config
self.robot = make_robot_from_config(config.robot)
self.robot.connect()
lerobot_features = map_robot_keys_to_lerobot_features(self.robot)
if config.verify_robot_cameras:
# Load policy config for validation
policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
policy_image_features = policy_config.image_features
# The cameras specified for inference must match the ones supported by the chosen policy
validate_robot_cameras_for_policy(lerobot_features, policy_image_features)
self.server_address = config.server_address
self.policy_config = RemotePolicyConfig(
config.policy_type,
config.pretrained_name_or_path,
lerobot_features,
config.actions_per_chunk,
config.policy_device,
)
self.channel = grpc.insecure_channel(self.server_address)
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
self._running_event = threading.Event()
# Initialize client side variables
self.latest_action_lock = threading.Lock()
self.latest_action = -1
self.action_chunk_size = -1
self._chunk_size_threshold = config.chunk_size_threshold
self.action_queue = Queue()
self.action_queue_lock = threading.Lock() # Protect queue operations
self.action_queue_size = []
self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop
# FPS measurement
self.fps_tracker = FPSTracker(target_fps=self.config.fps)
self.logger.info("Robot connected and ready")
# Use an event for thread-safe coordination
self.must_go = threading.Event()
self.must_go.set() # Initially set - observations qualify for direct processing
@property
def running(self):
return self._running_event.is_set()
def start(self):
"""Start the robot client and connect to the policy server"""
try:
# client-server handshake
start_time = time.perf_counter()
self.stub.Ready(async_inference_pb2.Empty())
end_time = time.perf_counter()
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
# send policy instructions
policy_config_bytes = pickle.dumps(self.policy_config)
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
self.logger.info("Sending policy instructions to policy server")
self.logger.debug(
f"Policy type: {self.policy_config.policy_type} | "
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
f"Device: {self.policy_config.device}"
)
self.stub.SendPolicyInstructions(policy_setup)
self._running_event.set()
return True
except grpc.RpcError as e:
self.logger.error(f"Failed to connect to policy server: {e}")
return False
def stop(self):
"""Stop the robot client"""
self._running_event.clear()
self.robot.disconnect()
self.logger.debug("Robot disconnected")
self.channel.close()
self.logger.debug("Client stopped, channel closed")
def send_observation(
self,
obs: TimedObservation,
) -> bool:
"""Send observation to the policy server.
Returns True if the observation was sent successfully, False otherwise."""
if not self.running:
raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")
if not isinstance(obs, TimedObservation):
raise ValueError("Input observation needs to be a TimedObservation!")
start_time = time.perf_counter()
observation_bytes = pickle.dumps(obs)
serialize_time = time.perf_counter() - start_time
self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")
try:
observation_iterator = send_bytes_in_chunks(
observation_bytes,
async_inference_pb2.Observation,
log_prefix="[CLIENT] Observation",
silent=True,
)
_ = self.stub.SendObservations(observation_iterator)
obs_timestep = obs.get_timestep()
self.logger.info(f"Sent observation #{obs_timestep}")
return True
except grpc.RpcError as e:
self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
return False
def _inspect_action_queue(self):
with self.action_queue_lock:
queue_size = self.action_queue.qsize()
timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
return queue_size, timestamps
def _aggregate_action_queues(
self,
incoming_actions: list[TimedAction],
aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
"""Finds the same timestep actions in the queue and aggregates them using the aggregate_fn"""
if aggregate_fn is None:
# default aggregate function: take the latest action
def aggregate_fn(x1, x2):
return x2
future_action_queue = Queue()
with self.action_queue_lock:
internal_queue = self.action_queue.queue
current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}
for new_action in incoming_actions:
with self.latest_action_lock:
latest_action = self.latest_action
# New action is older than the latest action in the queue, skip it
if new_action.get_timestep() <= latest_action:
continue
# If the new action's timestep is not in the current action queue, add it directly
elif new_action.get_timestep() not in current_action_queue:
future_action_queue.put(new_action)
continue
# If the new action's timestep is in the current action queue, aggregate it
# TODO: There is probably a way to do this with broadcasting of the two action tensors
future_action_queue.put(
TimedAction(
timestamp=new_action.get_timestamp(),
timestep=new_action.get_timestep(),
action=aggregate_fn(
current_action_queue[new_action.get_timestep()], new_action.get_action()
),
)
)
with self.action_queue_lock:
self.action_queue = future_action_queue
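The aggregation above merges a freshly received chunk with actions still queued: timesteps at or before the latest executed action are dropped, brand-new timesteps pass through unchanged, and overlapping timesteps are combined with `aggregate_fn` (latest-wins by default). A dict-based sketch of that merge, paired with a hypothetical weighted-average aggregator:

```python
def merge_action_chunks(current: dict, incoming: list, latest_step: int, aggregate_fn=None):
    """current: {timestep: action}; incoming: [(timestep, action), ...]."""
    if aggregate_fn is None:
        aggregate_fn = lambda old, new: new  # default: the latest chunk wins
    merged = {}
    for step, action in incoming:
        if step <= latest_step:
            continue                       # already executed, drop
        if step not in current:
            merged[step] = action          # new timestep, keep as-is
        else:
            merged[step] = aggregate_fn(current[step], action)
    return merged


def weighted_average(old, new, alpha=0.5):
    return alpha * old + (1 - alpha) * new


queued = {3: 1.0, 4: 2.0}
fresh = [(2, 9.0), (3, 3.0), (5, 7.0)]
merged = merge_action_chunks(queued, fresh, latest_step=2, aggregate_fn=weighted_average)
# step 2 dropped (stale), step 3 averaged to 2.0, step 5 kept as 7.0
```

Note that, as in `_aggregate_action_queues`, queued actions whose timestep is absent from the incoming chunk are discarded (step 4 above): the rebuilt queue is constructed from the incoming actions only.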
def receive_actions(self, verbose: bool = False):
"""Receive actions from the policy server"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Action receiving thread starting")
while self.running:
try:
# Call GetActions to fetch the next chunk of actions from the server
actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
if len(actions_chunk.data) == 0:
continue # received `Empty` from server, wait for next call
receive_time = time.time()
# Deserialize bytes back into list[TimedAction]
deserialize_start = time.perf_counter()
timed_actions = pickle.loads(actions_chunk.data) # nosec
deserialize_time = time.perf_counter() - deserialize_start
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
# Calculate network latency if we have matching observations
if len(timed_actions) > 0 and verbose:
with self.latest_action_lock:
latest_action = self.latest_action
self.logger.debug(f"Current latest action: {latest_action}")
# Get queue state before changes
old_size, old_timesteps = self._inspect_action_queue()
if not old_timesteps:
old_timesteps = [latest_action]  # queue was empty
# Log incoming actions
incoming_timesteps = [a.get_timestep() for a in timed_actions]
first_action_timestep = timed_actions[0].get_timestep()
server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000
self.logger.info(
f"Received action chunk for step #{first_action_timestep} | "
f"Latest action: #{latest_action} | "
f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
f"Deserialization time: {deserialize_time * 1000:.2f}ms"
)
# Update action queue
start_time = time.perf_counter()
self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
queue_update_time = time.perf_counter() - start_time
self.must_go.set() # after receiving actions, next empty queue triggers must-go processing!
if verbose:
# Get queue state after changes
new_size, new_timesteps = self._inspect_action_queue()
with self.latest_action_lock:
latest_action = self.latest_action
self.logger.info(
f"Latest action: {latest_action} | "
f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
)
self.logger.debug(
f"Queue update complete ({queue_update_time:.6f}s) | "
f"Before: {old_size} items | "
f"After: {new_size} items | "
)
except grpc.RpcError as e:
self.logger.error(f"Error receiving actions: {e}")
def actions_available(self):
"""Check if there are actions available in the queue"""
with self.action_queue_lock:
return not self.action_queue.empty()
def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
return action
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
"""Read and perform the next action from the local action queue"""
# Lock only for queue operations
get_start = time.perf_counter()
with self.action_queue_lock:
self.action_queue_size.append(self.action_queue.qsize())
# Get action from queue
timed_action = self.action_queue.get_nowait()
get_end = time.perf_counter() - get_start
_performed_action = self.robot.send_action(
self._action_tensor_to_action_dict(timed_action.get_action())
)
with self.latest_action_lock:
self.latest_action = timed_action.get_timestep()
if verbose:
with self.action_queue_lock:
current_queue_size = self.action_queue.qsize()
self.logger.debug(
f"Ts={timed_action.get_timestamp()} | "
f"Action #{timed_action.get_timestep()} performed | "
f"Queue size: {current_queue_size}"
)
self.logger.debug(
f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
)
return _performed_action
def _ready_to_send_observation(self):
"""Flags when the client is ready to send an observation"""
with self.action_queue_lock:
return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold
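`chunk_size_threshold` governs this check: the client sends a new observation only once the local queue has drained to at most that fraction of a full action chunk, so the server can compute the next chunk before the robot runs out of actions. A numeric sketch of the ratio test (a hypothetical helper mirroring the method above):

```python
def ready_to_send(queue_size: int, action_chunk_size: int, threshold: float) -> bool:
    # Mirrors _ready_to_send_observation: send once the queue drains to <= threshold of a chunk
    return queue_size / action_chunk_size <= threshold


# With 50-action chunks and chunk_size_threshold=0.5:
# 40 actions left (80% of a chunk) -> keep executing, don't send yet
# 20 actions left (40% of a chunk) -> send an observation to request the next chunk
```

At `threshold=0.0` the client only sends once the queue is empty (essentially sequential inference), while `threshold=1.0` makes it eligible to send at every control cycle (maximally eager).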
def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
try:
# Get serialized observation bytes from the function
start_time = time.perf_counter()
raw_observation: RawObservation = self.robot.get_observation()
raw_observation["task"] = task
with self.latest_action_lock:
latest_action = self.latest_action
observation = TimedObservation(
timestamp=time.time(), # need time.time() to compare timestamps across client and server
observation=raw_observation,
timestep=max(latest_action, 0),
)
obs_capture_time = time.perf_counter() - start_time
# If there are no actions left in the queue, the observation must go through processing!
with self.action_queue_lock:
observation.must_go = self.must_go.is_set() and self.action_queue.empty()
current_queue_size = self.action_queue.qsize()
_ = self.send_observation(observation)
self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
if observation.must_go:
# must-go event will be set again after receiving actions
self.must_go.clear()
if verbose:
# Calculate comprehensive FPS metrics
fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())
self.logger.info(
f"Obs #{observation.get_timestep()} | "
f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
f"Target: {fps_metrics['target_fps']:.2f}"
)
self.logger.debug(
f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
)
return raw_observation
except Exception as e:
self.logger.error(f"Error in observation sender: {e}")
def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
"""Combined function for executing actions and streaming observations"""
# Wait at barrier for synchronized start
self.start_barrier.wait()
self.logger.info("Control loop thread starting")
_performed_action = None
_captured_observation = None
while self.running:
control_loop_start = time.perf_counter()
"""Control loop: (1) Performing actions, when available"""
if self.actions_available():
_performed_action = self.control_loop_action(verbose)
"""Control loop: (2) Streaming observations to the remote policy server"""
if self._ready_to_send_observation():
_captured_observation = self.control_loop_observation(task, verbose)
self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
# Dynamically adjust sleep time to maintain the desired control frequency
time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))
return _captured_observation, _performed_action
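The `time.sleep(max(0, ...))` line implements a simple fixed-rate scheduler: each iteration sleeps only for whatever remains of `environment_dt` after the actual work, so the loop holds its target frequency as long as the work fits within the period. A self-contained sketch of the same pattern (helper name and timings are illustrative):

```python
import time


def run_at_fixed_rate(step_fn, dt: float, n_steps: int) -> float:
    """Call step_fn n_steps times, pacing iterations to one every dt seconds."""
    start = time.perf_counter()
    for _ in range(n_steps):
        t0 = time.perf_counter()
        step_fn()
        # Sleep only for the remainder of the period; clamp at zero so an
        # over-budget iteration never produces a negative sleep.
        time.sleep(max(0.0, dt - (time.perf_counter() - t0)))
    return time.perf_counter() - start


# 5 iterations at dt=0.02 s should take at least ~0.1 s in total.
elapsed = run_at_fixed_rate(lambda: None, dt=0.02, n_steps=5)
assert elapsed >= 0.09
```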
@draccus.wrap()
def async_client(cfg: RobotClientConfig):
logging.info(pformat(asdict(cfg)))
if cfg.robot.type not in SUPPORTED_ROBOTS:
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
client = RobotClient(cfg)
if client.start():
client.logger.info("Starting action receiver thread...")
# Create and start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
# Start action receiver thread
action_receiver_thread.start()
try:
# The main thread runs the control loop
client.control_loop(task=cfg.task)
finally:
client.stop()
action_receiver_thread.join()
if cfg.debug_visualize_queue_size:
visualize_action_queue_size(client.action_queue_size)
client.logger.info("Client stopped")
if __name__ == "__main__":
async_client() # run the client
@@ -0,0 +1,59 @@
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package async_inference;
// AsyncInference: from Robot perspective
// The Robot sends observations to, and executes actions received from, a remote Policy server
service AsyncInference {
// Robot -> Policy to share observations with a remote inference server
// Policy -> Robot to share actions predicted for given observations
rpc SendObservations(stream Observation) returns (Empty);
rpc GetActions(Empty) returns (Actions);
rpc SendPolicyInstructions(PolicySetup) returns (Empty);
rpc Ready(Empty) returns (Empty);
rpc Stop(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Observation {
// sent by Robot, to remote Policy
TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size
bytes data = 2;
}
message Actions {
// sent by remote Policy, to Robot
bytes data = 1;
}
message PolicySetup {
// sent by Robot to remote server, to init Policy
bytes data = 1;
}
message Empty {}
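`transfer_state` exists because a single gRPC message is capped in size (4 MB by default), so a large serialized observation is streamed as several `Observation` messages and reassembled on the server. A plausible framing scheme consistent with the enum above, sketched with plain Python stand-ins for the generated message classes (the authoritative chunking logic lives in the transport utilities):

```python
from dataclasses import dataclass

# Mirrors the TransferState enum values from the proto above.
TRANSFER_BEGIN, TRANSFER_MIDDLE, TRANSFER_END = 1, 2, 3


@dataclass
class Chunk:
    transfer_state: int
    data: bytes


def split_payload(payload: bytes, max_size: int) -> list[Chunk]:
    """Split a serialized observation into message-sized chunks."""
    parts = [payload[i : i + max_size] for i in range(0, len(payload), max_size)]
    last = len(parts) - 1
    return [
        Chunk(TRANSFER_BEGIN if i == 0 else TRANSFER_END if i == last else TRANSFER_MIDDLE, part)
        for i, part in enumerate(parts)
    ]


def reassemble(chunks: list[Chunk]) -> bytes:
    """Concatenate streamed chunks back into the original payload."""
    return b"".join(c.data for c in chunks)


payload = b"x" * 10_000
chunks = split_payload(payload, max_size=4_000)
assert [c.transfer_state for c in chunks] == [TRANSFER_BEGIN, TRANSFER_MIDDLE, TRANSFER_END]
assert reassemble(chunks) == payload
```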


@@ -0,0 +1,45 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'async_inference.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=190
_globals['_TRANSFERSTATE']._serialized_end=286
_globals['_OBSERVATION']._serialized_start=42
_globals['_OBSERVATION']._serialized_end=125
_globals['_ACTIONS']._serialized_start=127
_globals['_ACTIONS']._serialized_end=150
_globals['_POLICYSETUP']._serialized_start=152
_globals['_POLICYSETUP']._serialized_end=179
_globals['_EMPTY']._serialized_start=181
_globals['_EMPTY']._serialized_end=188
_globals['_ASYNCINFERENCE']._serialized_start=289
_globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)


@@ -0,0 +1,277 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from lerobot.transport import async_inference_pb2 as async__inference__pb2
GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in async_inference_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class AsyncInferenceStub:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendObservations = channel.stream_unary(
'/async_inference.AsyncInference/SendObservations',
request_serializer=async__inference__pb2.Observation.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.GetActions = channel.unary_unary(
'/async_inference.AsyncInference/GetActions',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Actions.FromString,
_registered_method=True)
self.SendPolicyInstructions = channel.unary_unary(
'/async_inference.AsyncInference/SendPolicyInstructions',
request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/async_inference.AsyncInference/Ready',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
self.Stop = channel.unary_unary(
'/async_inference.AsyncInference/Stop',
request_serializer=async__inference__pb2.Empty.SerializeToString,
response_deserializer=async__inference__pb2.Empty.FromString,
_registered_method=True)
class AsyncInferenceServicer:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
def SendObservations(self, request_iterator, context):
"""Robot -> Policy to share observations with a remote inference server
Policy -> Robot to share actions predicted for given observations
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetActions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendPolicyInstructions(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Stop(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_AsyncInferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendObservations': grpc.stream_unary_rpc_method_handler(
servicer.SendObservations,
request_deserializer=async__inference__pb2.Observation.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'GetActions': grpc.unary_unary_rpc_method_handler(
servicer.GetActions,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Actions.SerializeToString,
),
'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
servicer.SendPolicyInstructions,
request_deserializer=async__inference__pb2.PolicySetup.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
'Stop': grpc.unary_unary_rpc_method_handler(
servicer.Stop,
request_deserializer=async__inference__pb2.Empty.FromString,
response_serializer=async__inference__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'async_inference.AsyncInference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class AsyncInference:
"""AsyncInference: from Robot perspective
Robot send observations to & executes action received from a remote Policy server
"""
@staticmethod
def SendObservations(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/async_inference.AsyncInference/SendObservations',
async__inference__pb2.Observation.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def GetActions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/GetActions',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Actions.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendPolicyInstructions(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/SendPolicyInstructions',
async__inference__pb2.PolicySetup.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Ready',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Stop(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/async_inference.AsyncInference/Stop',
async__inference__pb2.Empty.SerializeToString,
async__inference__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)


@@ -0,0 +1,177 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end test of the asynchronous inference stack (client ↔ server).
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
is to exercise the full communication loop:
1. Client sends policy specification → Server
2. Client streams observations → Server
3. Server streams action chunks → Client
4. Client executes received actions
The test succeeds if the client receives at least one action chunk and the server
records at least one predicted timestep, demonstrating that the gRPC round-trip
works end-to-end using real (but lightweight) protocol messages.
"""
from __future__ import annotations
import threading
from concurrent import futures
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# End-to-end test
# -----------------------------------------------------------------------------
def test_async_inference_e2e(monkeypatch):
"""Tests the full asynchronous inference pipeline."""
# Import grpc-dependent modules inside the test function
import grpc
from lerobot.robots.utils import make_robot_from_config
from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
from lerobot.scripts.server.policy_server import PolicyServer
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.transport import (
async_inference_pb2, # type: ignore
async_inference_pb2_grpc, # type: ignore
)
from tests.mocks.mock_robot import MockRobotConfig
# Create a stub policy similar to test_policy_server.py
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self):
"""Empty image features since this test doesn't use images."""
return {}
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
return self
def model(self, batch):
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
# ------------------------------------------------------------------
# 1. Create PolicyServer instance with mock policy
# ------------------------------------------------------------------
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
policy_server = PolicyServer(policy_server_config)
# Replace the real policy with our fast, deterministic stub.
policy_server.policy = MockPolicy()
policy_server.actions_per_chunk = 20
policy_server.device = "cpu"
# Set up robot config and features
robot_config = MockRobotConfig()
mock_robot = make_robot_from_config(robot_config)
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
policy_server.lerobot_features = lerobot_features
# Force server to produce deterministic action chunks in test mode
policy_server.policy_type = "act"
def _fake_get_action_chunk(_self, _obs, _type="test"):
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
# Bypass potentially heavy model loading inside SendPolicyInstructions
def _fake_send_policy_instructions(self, request, context): # noqa: N802
return async_inference_pb2.Empty()
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
# Build gRPC server running a PolicyServer
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
# Use the host/port specified in the fixture's config
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
server.add_insecure_port(server_address)
server.start()
# ------------------------------------------------------------------
# 2. Create a RobotClient around the MockRobot
# ------------------------------------------------------------------
client_config = RobotClientConfig(
server_address=server_address,
robot=robot_config,
chunk_size_threshold=0.0,
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(client_config)
assert client.start(), "Client failed initial handshake with the server"
# Track action chunks received without modifying RobotClient
action_chunks_received = {"count": 0}
original_aggregate = client._aggregate_action_queues
def counting_aggregate(*args, **kwargs):
action_chunks_received["count"] += 1
return original_aggregate(*args, **kwargs)
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
# Start client threads
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
    control_thread = threading.Thread(target=client.control_loop, kwargs={"task": ""}, daemon=True)
action_thread.start()
control_thread.start()
# ------------------------------------------------------------------
# 3. System exchanges a few messages
# ------------------------------------------------------------------
# Wait for 5 seconds
server.wait_for_termination(timeout=5)
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
# ------------------------------------------------------------------
# 4. Stop the system
# ------------------------------------------------------------------
client.stop()
action_thread.join()
control_thread.join()
policy_server.stop()
server.stop(grace=None)


@@ -0,0 +1,459 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pickle
import time
import numpy as np
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.server.helpers import (
FPSTracker,
TimedAction,
TimedObservation,
observations_similar,
prepare_image,
prepare_raw_observation,
raw_observation_to_observation,
resize_robot_observation_image,
)
# ---------------------------------------------------------------------
# FPSTracker
# ---------------------------------------------------------------------
def test_fps_tracker_first_observation():
"""First observation should initialize timestamp and return 0 FPS."""
tracker = FPSTracker(target_fps=30.0)
timestamp = 1000.0
metrics = tracker.calculate_fps_metrics(timestamp)
assert tracker.first_timestamp == timestamp
assert tracker.total_obs_count == 1
assert metrics["avg_fps"] == 0.0
assert metrics["target_fps"] == 30.0
def test_fps_tracker_single_interval():
"""Two observations 1 second apart should give 1 FPS."""
tracker = FPSTracker(target_fps=30.0)
# First observation at t=0
metrics1 = tracker.calculate_fps_metrics(0.0)
assert metrics1["avg_fps"] == 0.0
# Second observation at t=1 (1 second later)
metrics2 = tracker.calculate_fps_metrics(1.0)
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_multiple_intervals():
"""Multiple observations should calculate correct average FPS."""
tracker = FPSTracker(target_fps=30.0)
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
for i, ts in enumerate(timestamps):
metrics = tracker.calculate_fps_metrics(ts)
if i == 0:
assert metrics["avg_fps"] == 0.0
elif i == len(timestamps) - 1:
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
expected_fps = 2.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
def test_fps_tracker_irregular_intervals():
"""FPS calculation should work with irregular time intervals."""
tracker = FPSTracker(target_fps=30.0)
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
for ts in timestamps:
metrics = tracker.calculate_fps_metrics(ts)
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
expected_fps = 4.0 / 3.0
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
# ---------------------------------------------------------------------
# TimedData helpers
# ---------------------------------------------------------------------
def test_timed_action_getters():
"""TimedAction stores & returns timestamp, action tensor and timestep."""
ts = time.time()
action = torch.arange(10)
ta = TimedAction(timestamp=ts, action=action, timestep=0)
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
torch.testing.assert_close(ta.get_action(), action)
assert ta.get_timestep() == 0
def test_timed_observation_getters():
"""TimedObservation stores & returns timestamp, dict and timestep."""
ts = time.time()
obs_dict = {"observation.state": torch.ones(6)}
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to.get_observation() is obs_dict
assert to.get_timestep() == 0
def test_timed_data_deserialization_data_getters():
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
The async-inference stack uses ``pickle.dumps`` to move these objects across
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
This test ensures that the payload keeps its content intact after
the (de)serialization round-trip.
"""
ts = time.time()
# ------------------------------------------------------------------
# TimedAction
# ------------------------------------------------------------------
original_action = torch.randn(6)
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
# Serialize → bytes → deserialize
ta_bytes = pickle.dumps(ta_in) # nosec
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
# Identity & content checks
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert ta_out.get_timestep() == 13
torch.testing.assert_close(ta_out.get_action(), original_action)
# ------------------------------------------------------------------
# TimedObservation
# ------------------------------------------------------------------
obs_dict = {"observation.state": torch.arange(4).float()}
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
to_bytes = pickle.dumps(to_in) # nosec
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
assert to_out.get_timestep() == 7
assert to_out.must_go is True
assert to_out.get_observation().keys() == obs_dict.keys()
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
# ---------------------------------------------------------------------
# observations_similar()
# ---------------------------------------------------------------------
def _make_obs(state: torch.Tensor) -> TimedObservation:
"""Create a TimedObservation with raw robot observation format."""
return TimedObservation(
timestamp=time.time(),
observation={
"shoulder": state[0].item() if len(state) > 0 else 0.0,
"elbow": state[1].item() if len(state) > 1 else 0.0,
"wrist": state[2].item() if len(state) > 2 else 0.0,
"gripper": state[3].item() if len(state) > 3 else 0.0,
},
timestep=0,
)
def test_observations_similar_true():
"""Distance below atol → observations considered similar."""
# Create mock lerobot features for the similarity check
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
}
}
obs1 = _make_obs(torch.zeros(4))
obs2 = _make_obs(0.5 * torch.ones(4))
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
obs3 = _make_obs(2.0 * torch.ones(4))
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
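The values in this test are consistent with a Euclidean-distance check on the flattened state vectors (similar when the distance is below `atol`): the 0.5-offset state sits at distance 1.0 from the all-zeros state, while the 2.0-offset state sits at distance 4.0. A sketch under that assumption (the authoritative metric lives in `observations_similar` itself):

```python
import math


def states_similar(s1: list[float], s2: list[float], atol: float) -> bool:
    # Euclidean distance between the two flattened state vectors.
    dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(s1, s2)))
    return dist < atol


assert states_similar([0.0] * 4, [0.5] * 4, atol=2.0)      # distance 1.0 < 2.0
assert not states_similar([0.0] * 4, [2.0] * 4, atol=2.0)  # distance 4.0 >= 2.0
```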
# ---------------------------------------------------------------------
# raw_observation_to_observation and helpers
# ---------------------------------------------------------------------
def _create_mock_robot_observation():
"""Create a mock robot observation with motor positions and camera images."""
return {
"shoulder": 1.0,
"elbow": 2.0,
"wrist": 3.0,
"gripper": 0.5,
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
}
def _create_mock_lerobot_features():
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
return {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
"observation.images.phone": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
}
def _create_mock_policy_image_features():
"""Create mock policy image features with different resolutions."""
return {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Policy expects smaller resolution
),
"observation.images.phone": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 160, 160), # Different resolution for second camera
),
}
def test_prepare_image():
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
# Create mock int8 image data
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
processed = prepare_image(image_int8)
# Check dtype conversion
assert processed.dtype == torch.float32
# Check normalization range
assert processed.min() >= 0.0
assert processed.max() <= 1.0
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
if image_int8.max() == 255:
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
if image_int8.min() == 0:
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
# Check memory contiguity
assert processed.is_contiguous()
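As the test pins down, `prepare_image`'s contract is uint8 → float32 with intensities mapped into [0, 1]; the scaling itself is just division by 255. A dependency-free sketch of that mapping (a hypothetical helper, not the torch-based implementation):

```python
def normalize_pixels(pixels: list[int]) -> list[float]:
    # Map uint8 intensities [0, 255] onto floats in [0.0, 1.0].
    return [p / 255.0 for p in pixels]


out = normalize_pixels([0, 128, 255])
assert out[0] == 0.0 and out[-1] == 1.0
assert all(0.0 <= v <= 1.0 for v in out)
```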
def test_resize_robot_observation_image():
"""Test image resizing from robot resolution to policy resolution."""
# Create mock image: (H=480, W=640, C=3)
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
target_shape = (3, 224, 224) # (C, H, W)
resized = resize_robot_observation_image(original_image, target_shape)
# Check output shape matches target
assert resized.shape == target_shape
# Check that original image had different dimensions
assert original_image.shape != resized.shape
# Check that resizing preserves value range
assert resized.min() >= 0
assert resized.max() <= 255
def test_prepare_raw_observation():
"""Test the preparation of raw robot observation to lerobot format."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
# Check that state is properly extracted and batched
assert "observation.state" in prepared
state = prepared["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.shape == (1, 4) # Batched state
# Check that images are processed and resized
assert "observation.images.laptop" in prepared
assert "observation.images.phone" in prepared
laptop_img = prepared["observation.images.laptop"]
phone_img = prepared["observation.images.phone"]
# Check image shapes match policy requirements
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
# Check that images are tensors
assert isinstance(laptop_img, torch.Tensor)
assert isinstance(phone_img, torch.Tensor)
def test_raw_observation_to_observation_basic():
"""Test the main raw_observation_to_observation function."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all expected keys are present
assert "observation.state" in observation
assert "observation.images.laptop" in observation
assert "observation.images.phone" in observation
# Check state processing
state = observation["observation.state"]
assert isinstance(state, torch.Tensor)
assert state.device.type == device
assert state.shape == (1, 4) # Batched
# Check image processing
laptop_img = observation["observation.images.laptop"]
phone_img = observation["observation.images.phone"]
# Images should have batch dimension: (B, C, H, W)
assert laptop_img.shape == (1, 3, 224, 224)
assert phone_img.shape == (1, 3, 160, 160)
# Check device placement
assert laptop_img.device.type == device
assert phone_img.device.type == device
# Check image dtype and range (should be float32 in [0, 1])
assert laptop_img.dtype == torch.float32
assert phone_img.dtype == torch.float32
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
def test_raw_observation_to_observation_with_non_tensor_data():
"""Test that non-tensor data (like task strings) is preserved."""
robot_obs = _create_mock_robot_observation()
robot_obs["task"] = "pick up the red cube" # Add string instruction
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that task string is preserved
assert "task" in observation
assert observation["task"] == "pick up the red cube"
assert isinstance(observation["task"], str)
@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
"""Test that tensors are properly moved to the specified device."""
device = "mps" if torch.backends.mps.is_available() else "cpu"
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all tensors are on the correct device
for key, value in observation.items():
if isinstance(value, torch.Tensor):
assert value.device.type == device, f"Tensor {key} not on {device}"
def test_raw_observation_to_observation_deterministic():
"""Test that the function produces consistent results for the same input."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
# Run twice with same input
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Results should be identical
assert set(obs1.keys()) == set(obs2.keys())
for key in obs1:
if isinstance(obs1[key], torch.Tensor):
torch.testing.assert_close(obs1[key], obs2[key])
else:
assert obs1[key] == obs2[key]
def test_image_processing_pipeline_preserves_content():
"""Test that the image processing pipeline preserves recognizable patterns."""
# Create an image with a specific pattern
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
original_img[25:75, 25:75, :] = 255 # White square in center
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
"dtype": "image",
"shape": [100, 100, 3],
"names": ["height", "width", "channels"],
},
}
policy_image_features = {
"observation.images.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 50, 50), # Downsamples from 100x100
)
}
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
# Check that the center region has higher values than corners
# Due to bilinear interpolation, exact values will differ, but the pattern should remain
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
corner_val = processed_img[:, 5, 5].mean() # Corner
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
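The preprocessing contract these tests exercise (uint8 HWC robot image → batched float32 CHW tensor in [0, 1]) can be sketched independently. `prepare_image` below is a hypothetical helper, not the library's actual implementation; in particular, the real pipeline may scale before or after resizing:

```python
import torch
import torch.nn.functional as F

def prepare_image(img_hwc_uint8: torch.Tensor, target_chw: tuple[int, int, int]) -> torch.Tensor:
    """(H, W, C) uint8 -> (1, C, H', W') float32 in [0, 1], matching the contract tested above."""
    c, h, w = target_chw
    x = img_hwc_uint8.permute(2, 0, 1).unsqueeze(0).float()  # (1, C, H, W)
    assert x.shape[1] == c, "channel count must match the policy feature"
    x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
    return (x / 255.0).contiguous()

img = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
out = prepare_image(img, (3, 224, 224))
assert out.shape == (1, 3, 224, 224) and out.dtype == torch.float32
assert 0.0 <= out.min().item() and out.max().item() <= 1.0
```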


@@ -0,0 +1,215 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `PolicyServer` core logic.
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
"""
from __future__ import annotations
import time
import pytest
import torch
from lerobot.configs.types import PolicyFeature
from tests.utils import require_package
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
class MockPolicy:
"""A minimal mock for an actual policy, returning zeros.
Refer to tests/policies for tests of the individual policies supported."""
class _Config:
robot_type = "dummy_robot"
@property
def image_features(self) -> dict[str, PolicyFeature]:
"""Empty image features since this test doesn't use images."""
return {}
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return a chunk of 20 dummy actions."""
batch_size = len(observation["observation.state"])
return torch.zeros(batch_size, 20, 6)
def __init__(self):
self.config = self._Config()
def to(self, *args, **kwargs):
# The server calls `policy.to(device)`. This stub ignores it.
return self
def model(self, batch: dict) -> torch.Tensor:
# Return a chunk of 20 dummy actions.
batch_size = len(batch["robot_type"])
return torch.zeros(batch_size, 20, 6)
@pytest.fixture
@require_package("grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import PolicyServer
test_config = PolicyServerConfig(host="localhost", port=9999)
server = PolicyServer(test_config)
# Replace the real policy with our fast, deterministic stub.
server.policy = MockPolicy()
server.actions_per_chunk = 20
server.device = "cpu"
# Add mock lerobot_features that the observation similarity functions need
server.lerobot_features = {
"observation.state": {
"dtype": "float32",
"shape": [6],
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
}
}
return server
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
"""Create a TimedObservation with a given state vector."""
# Import only when needed
from lerobot.scripts.server.helpers import TimedObservation
return TimedObservation(
observation={
"joint1": state[0].item() if len(state) > 0 else 0.0,
"joint2": state[1].item() if len(state) > 1 else 0.0,
"joint3": state[2].item() if len(state) > 2 else 0.0,
"joint4": state[3].item() if len(state) > 3 else 0.0,
"joint5": state[4].item() if len(state) > 4 else 0.0,
"joint6": state[5].item() if len(state) > 5 else 0.0,
},
timestamp=time.time(),
timestep=timestep,
must_go=must_go,
)
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_time_action_chunk(policy_server):
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
start_ts = time.time()
start_t = 10
# A chunk of 3 action tensors.
action_tensors = [torch.randn(6) for _ in range(3)]
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
assert len(timed_actions) == 3
# Check timesteps
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
# Check timestamps
expected_timestamps = [
start_ts,
start_ts + policy_server.config.environment_dt,
start_ts + 2 * policy_server.config.environment_dt,
]
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
def test_maybe_enqueue_observation_must_go(policy_server):
"""An observation with `must_go=True` is always enqueued."""
obs = _make_obs(torch.zeros(6), must_go=True)
assert policy_server._enqueue_observation(obs) is True
assert policy_server.observation_queue.qsize() == 1
assert policy_server.observation_queue.get_nowait() is obs
def test_maybe_enqueue_observation_dissimilar(policy_server):
"""A dissimilar observation (not `must_go`) is enqueued."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, dissimilar observation.
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
assert policy_server._enqueue_observation(new_obs) is True
assert policy_server.observation_queue.qsize() == 1
def test_maybe_enqueue_observation_is_skipped(policy_server):
"""A similar observation (not `must_go`) is skipped."""
# Set a last predicted observation.
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
# Create a new, very similar observation.
new_obs = _make_obs(torch.zeros(6) + 1e-4)
assert policy_server._enqueue_observation(new_obs) is False
assert policy_server.observation_queue.empty() is True
def test_obs_sanity_checks(policy_server):
"""Unit-test the private `_obs_sanity_checks` helper."""
prev = _make_obs(torch.zeros(6), timestep=0)
# Case 1: timestep already predicted
policy_server._predicted_timesteps.add(1)
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
# Case 2: observation too similar
policy_server._predicted_timesteps.clear()
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
# Case 3: genuinely new & dissimilar observation passes
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
def test_predict_action_chunk(monkeypatch, policy_server):
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
# Import only when needed
from lerobot.scripts.server.policy_server import PolicyServer
# Force server to act-style policy; patch method to return deterministic tensor
policy_server.policy_type = "act"
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk
def _fake_get_action_chunk(_self, _obs, _type="act"):
return torch.zeros(batch_size, actions_per_chunk, action_dim)
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
obs = _make_obs(torch.zeros(6), timestep=5)
timed_actions = policy_server._predict_action_chunk(obs)
assert len(timed_actions) == actions_per_chunk
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
for i, ta in enumerate(timed_actions):
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
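The timestamp/timestep assignment asserted in `test_time_action_chunk` reduces to simple arithmetic. A minimal standalone sketch, where `time_chunk` is a hypothetical free function assuming a fixed `environment_dt` in seconds:

```python
def time_chunk(start_ts: float, n_actions: int, start_t: int, environment_dt: float):
    """Pair each action index with its timestep and wall-clock timestamp."""
    return [(start_t + i, start_ts + i * environment_dt) for i in range(n_actions)]

timed = time_chunk(start_ts=100.0, n_actions=3, start_t=10, environment_dt=1 / 30)
assert [t for t, _ in timed] == [10, 11, 12]
assert abs(timed[2][1] - (100.0 + 2 / 30)) < 1e-9
```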


@@ -0,0 +1,234 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
The tests use a mock robot configuration so that no real hardware is accessed.
Only the queue-update mechanism is verified.
"""
from __future__ import annotations
import time
from queue import Queue
import pytest
import torch
# Skip entire module if grpc is not available
pytest.importorskip("grpc")
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
@pytest.fixture()
def robot_client():
"""Fresh `RobotClient` instance for each test case (no threads started).
Uses DummyRobot."""
# Import only when the test actually runs (after decorator check)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from tests.mocks.mock_robot import MockRobotConfig
test_config = MockRobotConfig()
# gRPC channel is not actually used in tests, so using a dummy address
test_config = RobotClientConfig(
robot=test_config,
server_address="localhost:9999",
policy_type="test",
pretrained_name_or_path="test",
actions_per_chunk=20,
verify_robot_cameras=False,
)
client = RobotClient(test_config)
# Initialize attributes that are normally set in start() method
client.chunks_received = 0
client.available_actions_size = []
yield client
if client.robot.is_connected:
client.stop()
# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------
def _make_actions(start_ts: float, start_t: int, count: int):
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
from lerobot.scripts.server.helpers import TimedAction
fps = 30 # emulates most common frame-rate
actions = []
for i in range(count):
timestep = start_t + i
timestamp = start_ts + i * (1 / fps)
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
return actions
# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------
def test_update_action_queue_discards_stale(robot_client):
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
# Pretend we already executed up to action #4
robot_client.latest_action = 4
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
robot_client._aggregate_action_queues(incoming)
# Extract timesteps from queue
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
assert resulting_timesteps == [5, 6, 7]
@pytest.mark.parametrize(
"weight_old, weight_new",
[
(1.0, 0.0),
(0.0, 1.0),
(0.5, 0.5),
(0.2, 0.8),
(0.8, 0.2),
(0.1, 0.9),
(0.9, 0.1),
],
)
def test_aggregate_action_queues_combines_actions_in_overlap(
robot_client, weight_old: float, weight_new: float
):
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
to the provided aggregate_fn, here tested with multiple coefficients."""
from lerobot.scripts.server.helpers import TimedAction
robot_client.chunks_received = 0
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
robot_client.latest_action = 4
current_actions = _make_actions(
start_ts=time.time(), start_t=5, count=2
)  # action tensors are filled with their timestep value: full((6,), 5), full((6,), 6)
current_actions = [
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
for a in current_actions
]
for a in current_actions:
robot_client.action_queue.put(a)
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
overlap_timesteps = [5, 6]  # stale-discarding itself is covered by test_update_action_queue_discards_stale
nonoverlap_timesteps = [7]
robot_client._aggregate_action_queues(
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
)
queue_overlap_actions = []
queue_non_overlap_actions = []
for a in robot_client.action_queue.queue:
if a.get_timestep() in overlap_timesteps:
queue_overlap_actions.append(a)
elif a.get_timestep() in nonoverlap_timesteps:
queue_non_overlap_actions.append(a)
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
assert torch.allclose(
queue_overlap_actions[0].get_action(),
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
)
assert torch.allclose(
queue_overlap_actions[1].get_action(),
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
)
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
@pytest.mark.parametrize(
"chunk_size, queue_len, expected",
[
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
(10, 5, True),
(10, 6, False),
],
)
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
robot_client.action_chunk_size = chunk_size
# Clear any existing actions, then fill with `queue_len` dummy entries
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected
@pytest.mark.parametrize(
"g_threshold, expected",
[
# The condition is `queue_size / chunk_size <= g`.
# Here, ratio = 6 / 10 = 0.6.
(0.0, False), # 0.6 <= 0.0 is False
(0.1, False),
(0.2, False),
(0.3, False),
(0.4, False),
(0.5, False),
(0.6, True), # 0.6 <= 0.6 is True
(0.7, True),
(0.8, True),
(0.9, True),
(1.0, True),
],
)
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
chunk_size = 10
queue_len = 6
robot_client.action_chunk_size = chunk_size
# This is the parameter we are testing
robot_client._chunk_size_threshold = g_threshold
# Fill queue with dummy actions
robot_client.action_queue = Queue()
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
for act in dummy_actions:
robot_client.action_queue.put(act)
assert robot_client._ready_to_send_observation() is expected
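The readiness rule both parametrized tests above pin down is just a ratio comparison. Sketched as a hypothetical free function `ready_to_send`, mirroring the documented condition `queue_size / chunk_size <= g`:

```python
def ready_to_send(queue_size: int, chunk_size: int, g: float) -> bool:
    """Request a new chunk once the remaining queue drops to a fraction g of a full chunk."""
    return queue_size / chunk_size <= g

assert ready_to_send(8, 20, 0.5) is True    # 0.4 <= 0.5
assert ready_to_send(12, 20, 0.5) is False  # 0.6 > 0.5
assert ready_to_send(6, 10, 0.6) is True    # boundary case: 0.6 <= 0.6
```

A lower `g` means the client waits longer before sending an observation (fewer, later requests); `g = 1.0` requests a new chunk on every opportunity.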