Add Async Inference (#1196)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Committed by GitHub
parent ce2b9724bf
commit 30c161006d
@@ -17,6 +17,8 @@
      title: Train a Robot with RL
    - local: hilserl_sim
      title: Train RL in Simulation
    - local: async
      title: Use Async Inference
  title: "Tutorials"
- sections:
  - local: smolvla

272 docs/source/async.mdx Normal file
@@ -0,0 +1,272 @@
# Asynchronous Inference

With [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**.
In this tutorial, we'll show how to use asynchronous inference (_async inference_) with a fine-tuned version of SmolVLA. The same workflow applies to all the policies supported by LeRobot.

**Try async inference with all the policies** supported by LeRobot!

**What you'll learn:**

1. Why asynchronous inference matters and how it compares to more traditional, sequential inference.
2. How to spin up a `PolicyServer` and connect a `RobotClient`, either on the same machine or over the network.
3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy.

If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)!

In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions, eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours.
This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions.

---

## Getting started with async inference

You can read more about asynchronous inference in our [blogpost](NOTE:blogpost). Here, we provide a getting-started guide to help you set up and run asynchronous inference on your own robot.

First, install `lerobot` with the `async` extra to pull in the additional dependencies required for async inference.

```shell
pip install -e ".[async]"
```
Then, spin up a policy server (in another terminal, or on a separate machine), specifying the host address and port the client will connect to:

```shell
python src/lerobot/scripts/server/policy_server.py \
    --host=127.0.0.1 \
    --port=8080
```

This starts a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage the policy server is empty: all information about which policy to run, and with which parameters, is specified during the first handshake with the client. Spin up a client with:
```shell
# SERVER: --server_address is the host address and port of the policy server
# ROBOT:  --robot.type, --robot.port and --robot.id give your robot type, port, and id (used to load the calibration file)
# POLICY: --robot.cameras lists the cameras used to acquire frames; keys must match those expected by the policy
# POLICY: --task is the task to run the policy on (e.g. `Fold my t-shirt`); not necessarily defined for all policies, such as `act`
# POLICY: --policy_type is the type of policy to run (smolvla, act, etc.)
# POLICY: --pretrained_name_or_path is the model name/path of the checkpoint the server should run (e.g. lerobot/smolvla_base)
# POLICY: --policy_device is the device the server runs inference on
# POLICY: --actions_per_chunk is the number of actions to output at once
# CLIENT: --chunk_size_threshold is the queue fraction below which a new observation is sent to the server
# CLIENT: --aggregate_fn_name is the function used to aggregate actions on overlapping portions
# CLIENT: --debug_visualize_queue_size toggles visualizing the queue size at runtime
python src/lerobot/scripts/server/robot_client.py \
    --server_address=127.0.0.1:8080 \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem585A0076841 \
    --robot.id=follower_so100 \
    --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
    --task="dummy" \
    --policy_type=your_policy_type \
    --pretrained_name_or_path=user/model \
    --policy_device=mps \
    --actions_per_chunk=50 \
    --chunk_size_threshold=0.5 \
    --aggregate_fn_name=weighted_average \
    --debug_visualize_queue_size=True
```
In summary, you need to specify instructions for:
- `SERVER`: the address and port of the policy server
- `ROBOT`: the type of robot to connect to, the port to connect through, and the local `id` of the robot
- `POLICY`: the type of policy to run, and the model name/path on the server of the checkpoint to run. You also need to specify which device the server should use, and how many actions to output at once (capped at the policy's max actions value).
- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function used to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters.

Importantly:
- `actions_per_chunk` and `chunk_size_threshold` are the key parameters to tune for your setup.
- `aggregate_fn_name` is the function used to aggregate actions on overlapping portions. You can either pick one from the registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC)).
- `debug_visualize_queue_size` is a useful tool for tuning the `CLIENT` parameters.
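As a sketch of what the aggregation registry looks like, the snippet below mirrors the `weighted_average` entry shipped with LeRobot and adds a hypothetical custom entry (`fresh_biased` is illustrative, not part of the library; real entries operate on `torch.Tensor` action chunks rather than the plain floats used here):

```python
# Minimal sketch of an aggregate-function registry: entries blend the
# overlapping portion of the old chunk with the newly received one.
AGGREGATE_FUNCTIONS = {
    "weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
    "latest_only": lambda old, new: new,
}

# Hypothetical custom entry that trusts the fresh chunk even more:
AGGREGATE_FUNCTIONS["fresh_biased"] = lambda old, new: 0.1 * old + 0.9 * new

merged = AGGREGATE_FUNCTIONS["fresh_biased"](0.0, 1.0)
print(merged)  # 0.9
```

Selecting an entry by name is exactly what `--aggregate_fn_name` does on the client side.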
Done! You should see your robot moving around by now 😉

---
## Async vs. synchronous inference

Synchronous inference interleaves action chunk prediction and action execution. This inherently results in *idle frames*: frames where the robot sits still, awaiting the policy's output, a new action chunk.
In turn, inference suffers from evident real-time lags, where the robot simply stops acting due to the lack of available actions.
With robotics models increasing in size, this problem risks becoming only more severe.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/sync.png" width="80%"></img>
</p>
<p align="center"><i>Synchronous inference</i> makes the robot idle while the policy is computing the next chunk of actions.</p>

To overcome this, we designed async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames.
Crucially, with async inference the next action chunk is computed *before* the current one is exhausted, so the robot never idles.
Higher adaptability comes from aggregating the different action chunks on their overlapping portions, yielding an up-to-date plan and a tighter control loop.

<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/async.png" width="80%"></img>
</p>
<p align="center"><i>Asynchronous inference</i> results in no idleness because the next chunk is computed before the current chunk is exhausted.</p>
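The difference can be made concrete with a little arithmetic. The sketch below uses illustrative numbers (not measurements): in the sync regime the robot idles for the full inference latency after exhausting each chunk, while in the async regime a new chunk is requested while a fraction of the old chunk is still queued, so no idling occurs as long as the queued actions take longer to execute than inference does.

```python
# Illustrative numbers: a 30 fps control loop, 50-action chunks,
# 400 ms inference latency, and a 0.5 chunk_size_threshold.
fps = 30
actions_per_chunk = 50
inference_latency = 0.4  # seconds
chunk_size_threshold = 0.5

# Sync: the robot executes the whole chunk, then waits for the next one.
sync_idle_per_chunk = inference_latency  # 0.4 s of idleness per chunk

# Async: a new observation is sent while threshold * chunk actions remain queued,
# so the queued actions keep the robot busy during inference.
remaining_runtime = (chunk_size_threshold * actions_per_chunk) / fps  # ~0.83 s
async_idle_per_chunk = max(0.0, inference_latency - remaining_runtime)

print(sync_idle_per_chunk, async_idle_per_chunk)  # 0.4 0.0
```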
---

## Start the Policy Server

A policy server is a wrapper around a `PreTrainedPolicy`, interfacing it with observations streamed by a robot client.
Policy servers are initialized as empty containers and populated with the policy requested in the initial handshake between the robot client and the policy server.
As such, spinning up a policy server is as easy as specifying a host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address.

<hfoptions id="start_policy_server">
<hfoption id="Command">

```bash
python -m lerobot.scripts.server.policy_server \
    --host="localhost" \
    --port=8080
```
</hfoption>
<hfoption id="API example">

```python
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import serve

config = PolicyServerConfig(
    host="localhost",
    port=8080,
)
serve(config)
```
</hfoption>
</hfoptions>

This listens on `localhost:8080` for an incoming connection from the associated `RobotClient`, which will communicate which policy to run during the first client-server handshake.
---

## Launch the Robot Client

`RobotClient` is a wrapper around a `Robot` instance that connects to the (possibly remote) `PolicyServer`.
The `RobotClient` streams observations to the `PolicyServer` and receives the action chunks obtained by running inference on the server (which we assume has better computational resources than the robot controller).
<hfoptions id="start_robot_client">
<hfoption id="Command">

```bash
# SERVER: --server_address is the host address and port of the policy server
# ROBOT:  --robot.type, --robot.port and --robot.id give your robot type, port, and id (used to load the calibration file)
# POLICY: --robot.cameras lists the cameras used to acquire frames; keys must match those expected by the policy
# POLICY: --task is the task to run the policy on (e.g. `Fold my t-shirt`); not necessarily defined for all policies, such as `act`
# POLICY: --policy_type is the type of policy to run (smolvla, act, etc.)
# POLICY: --pretrained_name_or_path is the model name/path of the checkpoint the server should run (e.g. lerobot/smolvla_base)
# POLICY: --policy_device is the device the server runs inference on
# POLICY: --actions_per_chunk is the number of actions to output at once
# CLIENT: --chunk_size_threshold is the queue fraction below which a new observation is sent to the server
# CLIENT: --aggregate_fn_name is the function used to aggregate actions on overlapping portions
# CLIENT: --debug_visualize_queue_size toggles visualizing the queue size at runtime
python src/lerobot/scripts/server/robot_client.py \
    --server_address=127.0.0.1:8080 \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem585A0076841 \
    --robot.id=follower_so100 \
    --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
    --task="dummy" \
    --policy_type=your_policy_type \
    --pretrained_name_or_path=user/model \
    --policy_device=mps \
    --actions_per_chunk=50 \
    --chunk_size_threshold=0.5 \
    --aggregate_fn_name=weighted_average \
    --debug_visualize_queue_size=True
```
</hfoption>
<hfoption id="API example">

```python
import threading

from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.scripts.server.helpers import visualize_action_queue_size

# 1. Create the camera configuration
# Check out the cameras available in your setup by running `python lerobot/find_cameras.py`.
# These cameras must match the ones expected by the policy:
# check the config.json on the Hub for the policy you are using.
camera_cfg = {
    "top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

# 2. Create the robot configuration
robot_cfg = SO100FollowerConfig(
    port="/dev/tty.usbmodem585A0076841",
    id="follower_so100",
    cameras=camera_cfg,
)

# 3. Create the client configuration
client_cfg = RobotClientConfig(
    robot=robot_cfg,
    server_address="localhost:8080",
    policy_device="mps",
    policy_type="smolvla",
    pretrained_name_or_path="fracapuano/smolvla_async",
    chunk_size_threshold=0.5,
    actions_per_chunk=50,  # make sure this is less than the max actions of the policy
)

# 4. Create and start the client
client = RobotClient(client_cfg)

# 5. Specify the task
task = "Don't do anything, stay still"

if client.start():
    # Start the action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()

    try:
        # Run the control loop
        client.control_loop(task)
    except KeyboardInterrupt:
        client.stop()
        action_receiver_thread.join()
        # (Optionally) plot the action queue size
        visualize_action_queue_size(client.action_queue_size)
```
</hfoption>
</hfoptions>
The following two parameters are key in every setup:

<table>
<thead>
<tr>
<th>Hyperparameter</th>
<th>Default</th>
<th>What it does</th>
</tr>
</thead>
<tbody>
<tr>
<td><code>actions_per_chunk</code></td>
<td>50</td>
<td>How many actions the policy outputs at once. Typical values: 10-50.</td>
</tr>
<tr>
<td><code>chunk_size_threshold</code></td>
<td>0.5</td>
<td>When the queue is ≤ 50% full, the client sends a fresh observation. Value in [0, 1].</td>
</tr>
</tbody>
</table>
<Tip>
Different values of `actions_per_chunk` and `chunk_size_threshold` result in noticeably different behaviours.
</Tip>

On the one hand, increasing `actions_per_chunk` reduces the likelihood of running out of actions to execute, as more actions remain available while the new chunk is computed.
However, larger values of `actions_per_chunk` may also result in less precise actions, due to compounding errors when predicting actions over longer timespans.

On the other hand, increasing `chunk_size_threshold` sends observations to the `PolicyServer` more often, producing a larger number of updated action chunks that overlap on significant portions. This results in high adaptability: in the limit, one action chunk is predicted for each observation, and each chunk is only marginally consumed before a new one is produced.
This also puts more pressure on the inference pipeline, as a consequence of the more frequent requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, where new observations are only sent once the current chunk is exhausted.

We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but we recommend experimenting with different values to find the best fit for your setup.
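The threshold mechanics above boil down to a simple client-side rule, sketched here with hypothetical names (the actual implementation lives in `robot_client.py`): a new observation is sent once the action queue drains to a fraction `chunk_size_threshold` of `actions_per_chunk`.

```python
actions_per_chunk = 50
chunk_size_threshold = 0.5

def should_send_observation(queue_size: int) -> bool:
    # Send a fresh observation once the queue holds at most
    # chunk_size_threshold * actions_per_chunk actions.
    return queue_size / actions_per_chunk <= chunk_size_threshold

print(should_send_observation(40))  # False: plenty of actions still queued
print(should_send_observation(25))  # True: queue is half empty, request a new chunk
```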
### Tuning async inference for your setup

1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. Identify the best compute for your use case, keeping in mind that smaller policies require fewer resources. The combination of policy and device (CPU, MPS, or the number of CUDA cores on a given NVIDIA GPU) directly determines the average inference latency you should expect.
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle: it keeps stepping through its current action queue. If the two processes run at fundamentally different speeds, the client can end up with an empty queue. As such, reduce your fps if you consistently run out of actions in the queue.
3. **Adjust `chunk_size_threshold`.**
   - Values closer to `0.0` result in almost sequential behavior; values closer to `1.0` send an observation at nearly every step (more bandwidth, and reliance on a good world model).
   - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` with `--debug_visualize_queue_size=True`. This plots the evolution of the action queue size at runtime, which you can use to find the value of `chunk_size_threshold` that works best for your setup.

<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/async-inference/queues.png" width="80%"></img>
</p>
<p align="center"><i>The action queue size is plotted at runtime when the `--debug_visualize_queue_size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).</i></p>
---

## Conclusion

Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors.

**Key Takeaways:**

- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel
- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, increasingly important as policy models grow larger
- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control
- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA

Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
@@ -46,7 +46,7 @@ classifiers = [
 ]
 dependencies = [
     "cmake>=3.29.0.1",
-    "datasets>=2.19.0",
+    "datasets>=2.19.0,<=3.6.0",
     "deepdiff>=7.0.1",
     "diffusers>=0.27.2",
     "draccus==0.10.0",
@@ -105,6 +105,7 @@ hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio
 umi = ["imagecodecs>=2024.1.1"]
 video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
 xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"]
+async = ["grpcio==1.71.0", "matplotlib>=3.10.3"]

 [tool.poetry]
 requires-poetry = ">=2.1"

197 src/lerobot/scripts/server/configs.py Normal file
@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Callable

import torch

from lerobot.robots.config import RobotConfig
from lerobot.scripts.server.constants import (
    DEFAULT_FPS,
    DEFAULT_INFERENCE_LATENCY,
    DEFAULT_OBS_QUEUE_TIMEOUT,
)

# Aggregate function registry for CLI usage
AGGREGATE_FUNCTIONS = {
    "weighted_average": lambda old, new: 0.3 * old + 0.7 * new,
    "latest_only": lambda old, new: new,
    "average": lambda old, new: 0.5 * old + 0.5 * new,
    "conservative": lambda old, new: 0.7 * old + 0.3 * new,
}


def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Get aggregate function by name from registry."""
    if name not in AGGREGATE_FUNCTIONS:
        available = list(AGGREGATE_FUNCTIONS.keys())
        raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}")
    return AGGREGATE_FUNCTIONS[name]


@dataclass
class PolicyServerConfig:
    """Configuration for PolicyServer.

    This class defines all configurable parameters for the PolicyServer,
    including networking settings and action chunking specifications.
    """

    # Networking configuration
    host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"})
    port: int = field(default=8080, metadata={"help": "Port number to bind the server to"})

    # Timing configuration
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})
    inference_latency: float = field(
        default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"}
    )

    obs_queue_timeout: float = field(
        default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"}
    )

    def __post_init__(self):
        """Validate configuration after initialization."""
        if self.port < 1 or self.port > 65535:
            raise ValueError(f"Port must be between 1 and 65535, got {self.port}")

        if self.environment_dt <= 0:
            raise ValueError(f"environment_dt must be positive, got {self.environment_dt}")

        if self.inference_latency < 0:
            raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}")

        if self.obs_queue_timeout < 0:
            raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}")

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PolicyServerConfig":
        """Create a PolicyServerConfig from a dictionary."""
        return cls(**config_dict)

    @property
    def environment_dt(self) -> float:
        """Environment time step, in seconds"""
        return 1 / self.fps

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary."""
        # NOTE: the derived `environment_dt` is deliberately not serialized,
        # so that `from_dict(config.to_dict())` round-trips (it is not an __init__ field).
        return {
            "host": self.host,
            "port": self.port,
            "fps": self.fps,
            "inference_latency": self.inference_latency,
        }


@dataclass
class RobotClientConfig:
    """Configuration for RobotClient.

    This class defines all configurable parameters for the RobotClient,
    including network connection, policy settings, and control behavior.
    """

    # Policy configuration
    policy_type: str = field(metadata={"help": "Type of policy to use"})
    pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"})

    # Robot configuration (for CLI usage - robot instance will be created from this)
    robot: RobotConfig = field(metadata={"help": "Robot configuration"})

    # Policies typically output K actions at most, but we can use fewer to avoid wasting bandwidth (actions
    # are aggregated on the client side anyway, depending on the value of `chunk_size_threshold`)
    actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"})

    # Task instruction for the robot to execute (e.g., 'fold my tshirt')
    task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"})

    # Network configuration
    server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"})

    # Device configuration
    policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})

    # Control behavior configuration
    chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
    fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"})

    # Aggregate function configuration (CLI-compatible)
    aggregate_fn_name: str = field(
        default="weighted_average",
        metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"},
    )

    # Debug configuration
    debug_visualize_queue_size: bool = field(
        default=False, metadata={"help": "Visualize the action queue size"}
    )

    # Verification configuration
    verify_robot_cameras: bool = field(
        default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"}
    )

    @property
    def environment_dt(self) -> float:
        """Environment time step, in seconds"""
        return 1 / self.fps

    def __post_init__(self):
        """Validate configuration after initialization."""
        if not self.server_address:
            raise ValueError("server_address cannot be empty")

        if not self.policy_type:
            raise ValueError("policy_type cannot be empty")

        if not self.pretrained_name_or_path:
            raise ValueError("pretrained_name_or_path cannot be empty")

        if not self.policy_device:
            raise ValueError("policy_device cannot be empty")

        if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
            raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")

        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")

        if self.actions_per_chunk <= 0:
            raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}")

        self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "RobotClientConfig":
        """Create a RobotClientConfig from a dictionary."""
        return cls(**config_dict)

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary."""
        return {
            "server_address": self.server_address,
            "policy_type": self.policy_type,
            "pretrained_name_or_path": self.pretrained_name_or_path,
            "policy_device": self.policy_device,
            "chunk_size_threshold": self.chunk_size_threshold,
            "fps": self.fps,
            "actions_per_chunk": self.actions_per_chunk,
            "task": self.task,
            "debug_visualize_queue_size": self.debug_visualize_queue_size,
            "aggregate_fn_name": self.aggregate_fn_name,
        }
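Both configs follow the same validate-in-`__post_init__` pattern. A self-contained sketch (names here are illustrative, not the library's classes):

```python
from dataclasses import dataclass

@dataclass
class MiniConfig:
    host: str = "localhost"
    port: int = 8080
    fps: int = 30

    def __post_init__(self):
        # Validation runs automatically right after __init__.
        if not (1 <= self.port <= 65535):
            raise ValueError(f"Port must be between 1 and 65535, got {self.port}")
        if self.fps <= 0:
            raise ValueError(f"fps must be positive, got {self.fps}")

    @property
    def environment_dt(self) -> float:
        # Derived value: one control step per frame, so it is not a stored field.
        return 1 / self.fps

cfg = MiniConfig()
print(cfg.environment_dt)  # ~0.0333 s per step
```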
29 src/lerobot/scripts/server/constants.py Normal file
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Client side: the environment evolves with a time resolution equal to 1/fps
DEFAULT_FPS = 30

# Server side: running inference in (at most) 1/fps
DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS

# Server side: timeout for the observation queue, in seconds
DEFAULT_OBS_QUEUE_TIMEOUT = 2

# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]

# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
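These constants imply the client's control loop takes one step every `1 / DEFAULT_FPS` seconds. A minimal fixed-rate pacing sketch (illustrative, not the actual client code): each step sleeps off whatever time its work left over, so steps land roughly `dt` apart.

```python
import time

DEFAULT_FPS = 30
dt = 1 / DEFAULT_FPS  # environment time step, ~33 ms

start = time.perf_counter()
for _ in range(3):
    step_start = time.perf_counter()
    # ... pop an action from the queue and send it to the robot ...
    elapsed = time.perf_counter() - step_start
    # Sleep the remainder of the step so the loop runs at DEFAULT_FPS.
    time.sleep(max(0.0, dt - elapsed))
total = time.perf_counter() - start
```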
386 src/lerobot/scripts/server/helpers.py Normal file
@@ -0,0 +1,386 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
|
||||
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.transport import async_inference_pb2
|
||||
from lerobot.transport.utils import bytes_buffer_size
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
ActionChunk = torch.Tensor
|
||||
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation, ready for policy inference (image keys resized)
|
||||
Observation = dict[str, torch.Tensor]
|
||||
|
||||
|
||||
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.set_title("Action Queue Size Over Time")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Action Queue Size")
    ax.set_ylim(0, max(action_queue_size) * 1.1)
    ax.grid(True, alpha=0.3)
    ax.plot(range(len(action_queue_size)), action_queue_size)
    plt.show()


def validate_robot_cameras_for_policy(
    lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature]
) -> None:
    image_keys = list(filter(is_image_key, lerobot_observation_features))
    assert set(image_keys) == set(policy_image_features.keys()), (
        f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}"
    )


def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
    return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)


def is_image_key(k: str) -> bool:
    return k.startswith(OBS_IMAGES)


def resize_robot_observation_image(image: torch.Tensor, resize_dims: tuple[int, int, int]) -> torch.Tensor:
    assert image.ndim == 3, f"Image must be (H, W, C)! Received {image.shape}"
    # (H, W, C) -> (C, H, W) for resizing from robot observation resolution to policy image resolution
    image = image.permute(2, 0, 1)
    dims = (resize_dims[1], resize_dims[2])
    # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W)
    image_batched = image.unsqueeze(0)
    # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W)
    resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False)

    return resized.squeeze(0)


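The resizing step above boils down to a permute plus a single `interpolate` call. A minimal standalone sketch, with a hypothetical `resize_hwc_image` name and dummy shapes for illustration:

```python
import torch
import torch.nn.functional as F


def resize_hwc_image(image: torch.Tensor, target_hw: tuple[int, int]) -> torch.Tensor:
    """Resize an (H, W, C) image tensor to (C, target_h, target_w), as the helper above does."""
    chw = image.permute(2, 0, 1).unsqueeze(0).float()  # (1, C, H, W), float for interpolate
    resized = F.interpolate(chw, size=target_hw, mode="bilinear", align_corners=False)
    return resized.squeeze(0)


img = torch.randint(0, 256, (480, 640, 3))  # robot-resolution frame, channel-last
out = resize_hwc_image(img, (224, 224))
print(tuple(out.shape))  # (3, 224, 224)
```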
def raw_observation_to_observation(
    raw_observation: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
    device: str,
) -> Observation:
    observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features)
    for k, v in observation.items():
        if isinstance(v, torch.Tensor):  # VLAs present natural-language instructions in observations
            if "image" in k:
                # Policy expects images in shape (B, C, H, W)
                observation[k] = prepare_image(v).unsqueeze(0).to(device)
            else:
                observation[k] = v.to(device)
        else:
            observation[k] = v

    return observation


def prepare_image(image: torch.Tensor) -> torch.Tensor:
    """Minimal preprocessing to turn uint8 [0, 255] images into float32 [0, 1], as a memory-contiguous tensor"""
    image = image.type(torch.float32) / 255
    image = image.contiguous()

    return image


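The normalization above is a one-liner worth seeing on concrete values; a sketch with a hypothetical `to_float_image` name:

```python
import torch


def to_float_image(image: torch.Tensor) -> torch.Tensor:
    """uint8 [0, 255] -> contiguous float32 [0, 1], mirroring prepare_image above."""
    return (image.type(torch.float32) / 255).contiguous()


img = torch.tensor([[0, 128, 255]], dtype=torch.uint8)
out = to_float_image(img)
print(out.dtype)  # torch.float32: values land at 0.0, ~0.5, 1.0
```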
def extract_state_from_raw_observation(
    lerobot_obs: RawObservation,
) -> torch.Tensor:
    """Extract the state from a raw observation."""
    state = torch.tensor(lerobot_obs[OBS_STATE])

    if state.ndim == 1:
        state = state.unsqueeze(0)

    return state


def extract_images_from_raw_observation(
    lerobot_obs: RawObservation,
    camera_key: str,
) -> torch.Tensor:
    """Extract a single camera image from a raw observation."""
    return torch.tensor(lerobot_obs[camera_key])


def make_lerobot_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
) -> LeRobotObservation:
    """Make a LeRobot observation from a raw observation."""
    return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")


def prepare_raw_observation(
    robot_obs: RawObservation,
    lerobot_features: dict[str, dict],
    policy_image_features: dict[str, PolicyFeature],
) -> Observation:
    """Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as
    policy_image_features)."""
    # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} ->
    # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray}
    lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features)

    # 2. Greps all observation.images.<camera> keys
    image_keys = list(filter(is_image_key, lerobot_obs))
    # state's shape is expected as (B, state_dim)
    state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)}

    # Turns the image features to (C, H, W) with H, W matching the policy image features.
    # This reduces the resolution of the images
    image_dict = {
        key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape)
        for key in image_keys
    }

    if "task" in robot_obs:
        state_dict["task"] = robot_obs["task"]

    return {**state_dict, **image_dict}


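The first step of `prepare_raw_observation` is pure key mapping: flat robot keys become `observation.state` / `observation.images.<camera>` entries. A self-contained sketch of that mapping, with hypothetical key names and a `to_lerobot_keys` helper introduced only for illustration:

```python
OBS_STATE, OBS_IMAGES = "observation.state", "observation.images"


def to_lerobot_keys(robot_obs: dict, motor_keys: list[str], camera_keys: list[str]) -> dict:
    """Map flat robot keys to LeRobot-style observation keys (step 1 above)."""
    obs = {OBS_STATE: [robot_obs[k] for k in motor_keys]}
    for cam in camera_keys:
        obs[f"{OBS_IMAGES}.{cam}"] = robot_obs[cam]
    return obs


obs = to_lerobot_keys(
    {"shoulder.pos": 0.1, "elbow.pos": -0.2, "front": "frame"},
    motor_keys=["shoulder.pos", "elbow.pos"],
    camera_keys=["front"],
)
print(sorted(obs))  # ['observation.images.front', 'observation.state']
```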
def get_logger(name: str, log_to_file: bool = True) -> logging.Logger:
    """
    Get a logger using the standardized logging setup from utils.py.

    Args:
        name: Logger name (e.g., 'policy_server', 'robot_client')
        log_to_file: Whether to also log to a file

    Returns:
        Configured logger instance
    """
    # Create logs directory if logging to file
    if log_to_file:
        os.makedirs("logs", exist_ok=True)
        log_file = Path(f"logs/{name}_{int(time.time())}.log")
    else:
        log_file = None

    # Initialize the standardized logging
    init_logging(log_file=log_file, display_pid=False)

    # Return a named logger
    return logging.getLogger(name)


@dataclass
class TimedData:
    """A data object with timestamp and timestep information.

    Subclasses carry the actual payload (e.g., an action or an observation).

    Args:
        timestamp: Unix timestamp at the data's creation.
        timestep: The timestep of the data.
    """

    timestamp: float
    timestep: int

    def get_timestamp(self):
        return self.timestamp

    def get_timestep(self):
        return self.timestep


@dataclass
class TimedAction(TimedData):
    action: Action

    def get_action(self):
        return self.action


@dataclass
class TimedObservation(TimedData):
    observation: RawObservation
    must_go: bool = False

    def get_observation(self):
        return self.observation


@dataclass
class FPSTracker:
    """Utility class to track FPS metrics over time."""

    target_fps: float
    first_timestamp: float | None = None
    total_obs_count: int = 0

    def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]:
        """Calculate average FPS vs target"""
        self.total_obs_count += 1

        # Initialize first observation time
        if self.first_timestamp is None:
            self.first_timestamp = current_timestamp

        # Calculate overall average FPS (since start)
        total_duration = current_timestamp - self.first_timestamp
        avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0

        return {"avg_fps": avg_fps, "target_fps": self.target_fps}

    def reset(self):
        """Reset the FPS tracker state"""
        self.first_timestamp = None
        self.total_obs_count = 0


@dataclass
class RemotePolicyConfig:
    policy_type: str
    pretrained_name_or_path: str
    lerobot_features: dict[str, PolicyFeature]
    actions_per_chunk: int
    device: str = "cpu"


def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:
    """Check if two observation states are similar, under a tolerance threshold"""
    return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)


def observations_similar(
    obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
) -> bool:
    """Check if two observations are similar, under a tolerance threshold. Measures the distance
    between observations as their difference in joint space.

    NOTE(fracapuano): This is a very simple check, and it is enough for the current use case.
    An immediate next step is to use (fast) perceptual difference metrics comparing some camera views,
    to surpass this joint-space similarity check.
    """
    obs1_state = extract_state_from_raw_observation(
        make_lerobot_observation(obs1.get_observation(), lerobot_features)
    )
    obs2_state = extract_state_from_raw_observation(
        make_lerobot_observation(obs2.get_observation(), lerobot_features)
    )

    return _compare_observation_states(obs1_state, obs2_state, atol=atol)


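The joint-space check above is just a Euclidean-norm threshold. A standalone sketch on toy joint vectors (hypothetical `states_similar` name, default tolerance of 1.0 as in the code above):

```python
import torch


def states_similar(s1: torch.Tensor, s2: torch.Tensor, atol: float = 1.0) -> bool:
    """Euclidean distance in joint space, as used to filter near-duplicate observations."""
    return bool(torch.linalg.norm(s1 - s2) < atol)


a = torch.tensor([0.0, 0.0, 0.0])
print(states_similar(a, torch.tensor([0.1, 0.1, 0.1])))  # True: small joint motion, filtered out
print(states_similar(a, torch.tensor([2.0, 0.0, 0.0])))  # False: large joint motion, kept
```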
def send_bytes_in_chunks(
    buffer: bytes,
    message_class: Any,
    log_prefix: str = "",
    silent: bool = True,
    chunk_size: int = 3 * 1024 * 1024,
):
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the
    # chunk size as I am using it to send image observations.
    buffer = io.BytesIO(buffer)
    size_in_bytes = bytes_buffer_size(buffer)

    sent_bytes = 0

    logging_method = logging.info if not silent else logging.debug

    logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB")

    while sent_bytes < size_in_bytes:
        transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE

        if sent_bytes + chunk_size >= size_in_bytes:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN

        size_to_read = min(chunk_size, size_in_bytes - sent_bytes)
        chunk = buffer.read(size_to_read)

        yield message_class(transfer_state=transfer_state, data=chunk)
        sent_bytes += size_to_read
        logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")

    logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")


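The chunking protocol above tags each chunk with a BEGIN/MIDDLE/END transfer state. A dependency-free sketch of the same state machine, with plain strings standing in for the protobuf enum (hypothetical `chunk_states` name):

```python
import io

BEGIN, MIDDLE, END = "BEGIN", "MIDDLE", "END"


def chunk_states(payload: bytes, chunk_size: int):
    """Yield (state, chunk) pairs following the same BEGIN/MIDDLE/END protocol as above."""
    buf, total, sent = io.BytesIO(payload), len(payload), 0
    while sent < total:
        if sent + chunk_size >= total:
            state = END  # last chunk (checked first, as in the original)
        elif sent == 0:
            state = BEGIN
        else:
            state = MIDDLE
        chunk = buf.read(min(chunk_size, total - sent))
        sent += len(chunk)
        yield state, chunk


states = [s for s, _ in chunk_states(b"x" * 10, chunk_size=4)]
print(states)  # ['BEGIN', 'MIDDLE', 'END']
```

Note that, as in the original, a payload that fits in a single chunk is tagged `END` rather than `BEGIN`; the receiver above handles that by returning on `TRANSFER_END`.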
def receive_bytes_in_chunks(
    iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = ""
):  # type: ignore
    # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we
    # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for
    # receiving is the opposite of the HIL-SERL design (my event signals keeping on running instead of shutdown)
    bytes_buffer = io.BytesIO()
    step = 0

    logger.info(f"{log_prefix} Starting receiver")
    for item in iterator:
        logger.debug(f"{log_prefix} Received item")
        if not continue_receiving.is_set():
            logger.info(f"{log_prefix} Shutting down receiver")
            return

        if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN:
            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)
            bytes_buffer.write(item.data)
            logger.debug(f"{log_prefix} Received data at step 0")

        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE:
            bytes_buffer.write(item.data)
            step += 1
            logger.debug(f"{log_prefix} Received data at step {step}")

        elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END:
            bytes_buffer.write(item.data)
            logger.debug(f"{log_prefix} Received data at final step, size {bytes_buffer_size(bytes_buffer)}")

            complete_bytes = bytes_buffer.getvalue()

            bytes_buffer.seek(0)
            bytes_buffer.truncate(0)

            logger.debug(f"{log_prefix} Buffer complete")
            return complete_bytes

        else:
            logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}")
            raise ValueError(f"Received unknown transfer state {item.transfer_state}")
403	src/lerobot/scripts/server/policy_server.py	Normal file
@@ -0,0 +1,403 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example:
```shell
python src/lerobot/scripts/server/policy_server.py \
    --host=127.0.0.1 \
    --port=8080 \
    --fps=30 \
    --inference_latency=0.033 \
    --obs_queue_timeout=1
```
"""

import logging
import pickle  # nosec
import threading
import time
from concurrent import futures
from dataclasses import asdict
from pprint import pformat
from queue import Empty, Queue

import draccus
import grpc
import torch

from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
    FPSTracker,
    Observation,
    RemotePolicyConfig,
    TimedAction,
    TimedObservation,
    get_logger,
    observations_similar,
    raw_observation_to_observation,
    receive_bytes_in_chunks,
)
from lerobot.transport import (
    async_inference_pb2,  # type: ignore
    async_inference_pb2_grpc,  # type: ignore
)


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    prefix = "policy_server"
    logger = get_logger(prefix)

    def __init__(self, config: PolicyServerConfig):
        self.config = config
        self._running_event = threading.Event()

        # FPS measurement
        self.fps_tracker = FPSTracker(target_fps=config.fps)

        self.observation_queue = Queue(maxsize=1)

        self._predicted_timesteps_lock = threading.Lock()
        self._predicted_timesteps = set()

        self.last_processed_obs = None

        # Attributes will be set by SendPolicyInstructions
        self.device = None
        self.policy_type = None
        self.lerobot_features = None
        self.actions_per_chunk = None
        self.policy = None

    @property
    def running(self):
        return self._running_event.is_set()

    @property
    def policy_image_features(self):
        return self.policy.config.image_features

    def _reset_server(self) -> None:
        """Flushes server state when a new client connects."""
        # only run inference on the latest observation received by the server
        self._running_event.clear()
        self.observation_queue = Queue(maxsize=1)

        with self._predicted_timesteps_lock:
            self._predicted_timesteps = set()

    def Ready(self, request, context):  # noqa: N802
        client_id = context.peer()
        self.logger.info(f"Client {client_id} connected and ready")
        self._reset_server()
        self._running_event.set()

        return async_inference_pb2.Empty()

    def SendPolicyInstructions(self, request, context):  # noqa: N802
        """Receive policy instructions from the robot client"""

        if not self.running:
            self.logger.warning("Server is not running. Ignoring policy instructions.")
            return async_inference_pb2.Empty()

        client_id = context.peer()

        policy_specs = pickle.loads(request.data)  # nosec

        if not isinstance(policy_specs, RemotePolicyConfig):
            raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}")

        if policy_specs.policy_type not in SUPPORTED_POLICIES:
            raise ValueError(
                f"Policy type {policy_specs.policy_type} not supported. "
                f"Supported policies: {SUPPORTED_POLICIES}"
            )

        self.logger.info(
            f"Receiving policy instructions from {client_id} | "
            f"Policy type: {policy_specs.policy_type} | "
            f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
            f"Actions per chunk: {policy_specs.actions_per_chunk} | "
            f"Device: {policy_specs.device}"
        )

        self.device = policy_specs.device
        self.policy_type = policy_specs.policy_type  # act, pi0, etc.
        self.lerobot_features = policy_specs.lerobot_features
        self.actions_per_chunk = policy_specs.actions_per_chunk

        policy_class = get_policy_class(self.policy_type)

        start = time.perf_counter()
        self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
        self.policy.to(self.device)
        end = time.perf_counter()

        self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

        return async_inference_pb2.Empty()

    def SendObservations(self, request_iterator, context):  # noqa: N802
        """Receive observations from the robot client"""
        client_id = context.peer()
        self.logger.debug(f"Receiving observations from {client_id}")

        receive_time = time.time()  # comparing timestamps, so need time.time()
        start_deserialize = time.perf_counter()
        received_bytes = receive_bytes_in_chunks(
            request_iterator, self._running_event, self.logger
        )  # blocking call while looping over request_iterator
        timed_observation = pickle.loads(received_bytes)  # nosec
        deserialize_time = time.perf_counter() - start_deserialize

        self.logger.debug(f"Received observation #{timed_observation.get_timestep()}")

        obs_timestep = timed_observation.get_timestep()
        obs_timestamp = timed_observation.get_timestamp()

        # Calculate FPS metrics
        fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)

        self.logger.info(
            f"Received observation #{obs_timestep} | "
            f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "  # fps at which observations are received from the client
            f"Target: {fps_metrics['target_fps']:.2f} | "
            f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms"
        )

        self.logger.debug(
            f"Server timestamp: {receive_time:.6f} | "
            f"Client timestamp: {obs_timestamp:.6f} | "
            f"Deserialization time: {deserialize_time:.6f}s"
        )

        if not self._enqueue_observation(
            timed_observation  # wrapping a RawObservation
        ):
            self.logger.info(f"Observation #{obs_timestep} has been filtered out")

        return async_inference_pb2.Empty()

    def GetActions(self, request, context):  # noqa: N802
        """Returns actions to the robot client. Actions are sent as a single
        chunk, containing multiple actions."""
        client_id = context.peer()
        self.logger.debug(f"Client {client_id} connected for action streaming")

        # Generate actions based on the most recent observation and its timestep
        try:
            getactions_starts = time.perf_counter()
            obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout)
            self.logger.info(
                f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})"
            )

            with self._predicted_timesteps_lock:
                self._predicted_timesteps.add(obs.get_timestep())

            start_time = time.perf_counter()
            action_chunk = self._predict_action_chunk(obs)
            inference_time = time.perf_counter() - start_time

            start_time = time.perf_counter()
            actions_bytes = pickle.dumps(action_chunk)  # nosec
            serialize_time = time.perf_counter() - start_time

            # Create and return the action chunk
            actions = async_inference_pb2.Actions(data=actions_bytes)

            self.logger.info(
                f"Action chunk #{obs.get_timestep()} generated | "
                f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms"
            )

            self.logger.debug(
                f"Action chunk #{obs.get_timestep()} generated | "
                f"Inference time: {inference_time:.2f}s | "
                f"Serialize time: {serialize_time:.2f}s | "
                f"Total time: {inference_time + serialize_time:.2f}s"
            )

            time.sleep(
                max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts))
            )  # sleep controls inference latency

            return actions

        except Empty:  # no observation added to the queue within obs_queue_timeout
            return async_inference_pb2.Empty()

        except Exception as e:
            self.logger.error(f"Error in GetActions: {e}")

            return async_inference_pb2.Empty()

    def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool:
        """Check if the observation is valid to be processed by the policy"""
        with self._predicted_timesteps_lock:
            predicted_timesteps = self._predicted_timesteps

        if obs.get_timestep() in predicted_timesteps:
            self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!")
            return False

        elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features):
            self.logger.debug(
                f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!"
            )
            return False

        else:
            return True

    def _enqueue_observation(self, obs: TimedObservation) -> bool:
        """Enqueue an observation if it must go through processing, otherwise skip it.
        Observations not in the queue are never run through the policy network"""

        if (
            obs.must_go
            or self.last_processed_obs is None
            or self._obs_sanity_checks(obs, self.last_processed_obs)
        ):
            last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None"
            self.logger.debug(
                f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}"
            )

            # If the queue is full, pop the old observation to make room
            if self.observation_queue.full():
                _ = self.observation_queue.get_nowait()
                self.logger.debug("Observation queue was full, removed oldest observation")

            # Now put the new observation (never blocks, as the queue is non-full here)
            self.observation_queue.put(obs)
            return True

        return False

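The maxsize-1 queue plus pop-before-put pattern above implements "latest observation wins": the policy only ever sees the freshest observation. A dependency-free sketch (hypothetical `put_latest` name):

```python
from queue import Queue


def put_latest(q: Queue, item) -> None:
    """Keep only the most recent item: drop the stale one if the queue is full, as the server does."""
    if q.full():
        _ = q.get_nowait()  # discard the stale observation
    q.put(item)  # never blocks: the queue is non-full here


q = Queue(maxsize=1)
for obs_id in range(5):
    put_latest(q, obs_id)

latest = q.get()
print(latest)  # 4: only the latest observation survives
```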
    def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
        """Turn a chunk of actions into a list of TimedAction instances,
        with the first action corresponding to t_0 and the rest corresponding to
        t_0 + i*environment_dt for i in range(len(action_chunk))
        """
        return [
            TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action)
            for i, action in enumerate(action_chunk)
        ]

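The timestamping above anchors the whole chunk to the observation's timestamp and timestep, spacing actions by the environment dt. A standalone sketch with a simplified `TimedAction` (floats stand in for tensor actions; names are illustrative):

```python
from dataclasses import dataclass


@dataclass
class TimedAction:
    timestamp: float
    timestep: int
    action: float  # stand-in for a torch.Tensor action


def time_action_chunk(t_0: float, actions: list[float], i_0: int, environment_dt: float) -> list[TimedAction]:
    """Assign timestamp t_0 + i*dt and timestep i_0 + i to the i-th action of a chunk."""
    return [
        TimedAction(timestamp=t_0 + i * environment_dt, timestep=i_0 + i, action=a)
        for i, a in enumerate(actions)
    ]


chunk = time_action_chunk(t_0=1.0, actions=[0.1, 0.2, 0.3], i_0=10, environment_dt=0.1)
print([(round(a.timestamp, 1), a.timestep) for a in chunk])  # [(1.0, 10), (1.1, 11), (1.2, 12)]
```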
    def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
        """
        Prepare an observation so that it is ready for policy inference.
        E.g.: to keep the observation sampling rate high (and the network packets tiny), the client sends
        uint8 [0, 255] images, which are converted to float32 [0, 1] images here, before running inference.
        """
        # RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
        observation: Observation = raw_observation_to_observation(
            observation_t.get_observation(),
            self.lerobot_features,
            self.policy_image_features,
            self.device,
        )
        # processed Observation - right keys, right dtype, right image shape

        return observation

    def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Get an action chunk from the policy. The chunk contains only the first `actions_per_chunk` actions."""
        chunk = self.policy.predict_action_chunk(observation)
        if chunk.ndim != 3:
            chunk = chunk.unsqueeze(0)  # adding batch dimension, now shape is (B, chunk_size, action_dim)

        return chunk[:, : self.actions_per_chunk, :]

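The batching-and-slicing above is worth seeing on concrete shapes: a policy may return either `(chunk_size, action_dim)` or `(B, chunk_size, action_dim)`, and the server keeps only the first `actions_per_chunk` actions. A sketch with dummy tensors (hypothetical `slice_chunk` name):

```python
import torch


def slice_chunk(chunk: torch.Tensor, actions_per_chunk: int) -> torch.Tensor:
    """Normalize a chunk to (B, chunk_size, action_dim) and keep the first actions_per_chunk actions."""
    if chunk.ndim != 3:
        chunk = chunk.unsqueeze(0)  # add batch dimension
    return chunk[:, :actions_per_chunk, :]


chunk = torch.zeros(100, 6)  # e.g. a policy returning an unbatched (chunk_size, action_dim) chunk
out = slice_chunk(chunk, actions_per_chunk=50)
print(tuple(out.shape))  # (1, 50, 6)
```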
    def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
        """Predict an action chunk based on an observation"""
        inference_starts = time.perf_counter()

        """1. Prepare observation"""
        start_time = time.perf_counter()
        observation = self._prepare_observation(observation_t)
        preprocessing_time = time.perf_counter() - start_time

        self.last_processed_obs: TimedObservation = observation_t

        """2. Get action chunk"""
        start_time = time.perf_counter()
        action_tensor = self._get_action_chunk(observation)
        inference_time = time.perf_counter() - start_time

        """3. Post-inference processing"""
        start_time = time.perf_counter()
        # Move to CPU before serializing
        action_tensor = action_tensor.cpu().squeeze(0)

        action_chunk = self._time_action_chunk(
            observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
        )
        postprocessing_time = time.perf_counter() - start_time
        inference_stops = time.perf_counter()

        self.logger.info(
            f"Observation {observation_t.get_timestep()} | "
            f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
        )

        # full-process latency breakdown for debugging purposes
        self.logger.debug(
            f"Observation {observation_t.get_timestep()} | "
            f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
            f"Inference time: {1000 * inference_time:.2f}ms | "
            f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
            f"Total time: {1000 * (inference_stops - inference_starts):.2f}ms"
        )

        return action_chunk

    def stop(self):
        """Stop the server"""
        self._reset_server()
        self.logger.info("Server stopping...")


@draccus.wrap()
def serve(cfg: PolicyServerConfig):
    """Start the PolicyServer with the given configuration.

    Args:
        cfg: PolicyServerConfig instance. If None, uses the default configuration.
    """
    logging.info(pformat(asdict(cfg)))

    # Create the server instance first
    policy_server = PolicyServer(cfg)

    # Setup and start gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"{cfg.host}:{cfg.port}")

    policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}")
    server.start()

    server.wait_for_termination()

    policy_server.logger.info("Server terminated")


if __name__ == "__main__":
    serve()
509	src/lerobot/scripts/server/robot_client.py	Normal file
@@ -0,0 +1,509 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example command:
```shell
python src/lerobot/scripts/server/robot_client.py \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58760431541 \
    --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
    --robot.id=black \
    --task="dummy" \
    --server_address=127.0.0.1:8080 \
    --policy_type=act \
    --pretrained_name_or_path=user/model \
    --policy_device=mps \
    --actions_per_chunk=50 \
    --chunk_size_threshold=0.5 \
    --aggregate_fn_name=weighted_average \
    --debug_visualize_queue_size=True
```
"""

import logging
import pickle  # nosec
import threading
import time
from dataclasses import asdict
from pprint import pformat
from queue import Queue
from typing import Any, Callable, Optional

import draccus
import grpc
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig  # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig  # noqa: F401
from lerobot.configs.policies import PreTrainedConfig
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    koch_follower,
    make_robot_from_config,
    so100_follower,
    so101_follower,
)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
from lerobot.scripts.server.helpers import (
    Action,
    FPSTracker,
    Observation,
    RawObservation,
    RemotePolicyConfig,
    TimedAction,
    TimedObservation,
    get_logger,
    map_robot_keys_to_lerobot_features,
    send_bytes_in_chunks,
    validate_robot_cameras_for_policy,
    visualize_action_queue_size,
)
from lerobot.transport import (
    async_inference_pb2,  # type: ignore
    async_inference_pb2_grpc,  # type: ignore
)


class RobotClient:
    prefix = "robot_client"
    logger = get_logger(prefix)

    def __init__(self, config: RobotClientConfig):
        """Initialize RobotClient with unified configuration.

        Args:
            config: RobotClientConfig containing all configuration parameters
        """
        # Store configuration
        self.config = config
        self.robot = make_robot_from_config(config.robot)
        self.robot.connect()

        lerobot_features = map_robot_keys_to_lerobot_features(self.robot)

        if config.verify_robot_cameras:
            # Load policy config for validation
            policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path)
            policy_image_features = policy_config.image_features

            # The cameras specified for inference must match the ones supported by the chosen policy
            validate_robot_cameras_for_policy(lerobot_features, policy_image_features)

        self.server_address = config.server_address

        self.policy_config = RemotePolicyConfig(
            config.policy_type,
            config.pretrained_name_or_path,
            lerobot_features,
            config.actions_per_chunk,
            config.policy_device,
        )
        self.channel = grpc.insecure_channel(self.server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        self.logger.info(f"Initializing client to connect to server at {self.server_address}")

        self._running_event = threading.Event()

        # Initialize client-side variables
        self.latest_action_lock = threading.Lock()
        self.latest_action = -1
        self.action_chunk_size = -1

        self._chunk_size_threshold = config.chunk_size_threshold

        self.action_queue = Queue()
        self.action_queue_lock = threading.Lock()  # Protect queue operations
        self.action_queue_size = []
        self.start_barrier = threading.Barrier(2)  # 2 threads: action receiver, control loop

        # FPS measurement
        self.fps_tracker = FPSTracker(target_fps=self.config.fps)

        self.logger.info("Robot connected and ready")

        # Use an event for thread-safe coordination
        self.must_go = threading.Event()
        self.must_go.set()  # Initially set - observations qualify for direct processing

    @property
    def running(self):
        return self._running_event.is_set()

def start(self):
|
||||
"""Start the robot client and connect to the policy server"""
|
||||
try:
|
||||
# client-server handshake
|
||||
start_time = time.perf_counter()
|
||||
self.stub.Ready(async_inference_pb2.Empty())
|
||||
end_time = time.perf_counter()
|
||||
self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s")
|
||||
|
||||
# send policy instructions
|
||||
policy_config_bytes = pickle.dumps(self.policy_config)
|
||||
policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes)
|
||||
|
||||
self.logger.info("Sending policy instructions to policy server")
|
||||
self.logger.debug(
|
||||
f"Policy type: {self.policy_config.policy_type} | "
|
||||
f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
|
||||
f"Device: {self.policy_config.device}"
|
||||
)
|
||||
|
||||
self.stub.SendPolicyInstructions(policy_setup)
|
||||
|
||||
self._running_event.set()
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
self.logger.error(f"Failed to connect to policy server: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""Stop the robot client"""
|
||||
self._running_event.clear()
|
||||
|
||||
self.robot.disconnect()
|
||||
self.logger.debug("Robot disconnected")
|
||||
|
||||
self.channel.close()
|
||||
self.logger.debug("Client stopped, channel closed")
|
||||
|
||||
    def send_observation(
        self,
        obs: TimedObservation,
    ) -> bool:
        """Send an observation to the policy server.

        Returns True if the observation was sent successfully, False otherwise."""
        if not self.running:
            raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.")

        if not isinstance(obs, TimedObservation):
            raise ValueError("Input observation needs to be a TimedObservation!")

        start_time = time.perf_counter()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.perf_counter() - start_time
        self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s")

        try:
            observation_iterator = send_bytes_in_chunks(
                observation_bytes,
                async_inference_pb2.Observation,
                log_prefix="[CLIENT] Observation",
                silent=True,
            )
            _ = self.stub.SendObservations(observation_iterator)
            obs_timestep = obs.get_timestep()
            self.logger.info(f"Sent observation #{obs_timestep}")

            return True

        except grpc.RpcError as e:
            self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

    def _inspect_action_queue(self):
        with self.action_queue_lock:
            queue_size = self.action_queue.qsize()
            timestamps = sorted([action.get_timestep() for action in self.action_queue.queue])
        self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
        return queue_size, timestamps
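Observations can exceed gRPC's default 4 MB message cap, which is why `send_bytes_in_chunks` streams each serialized observation as several `Observation` messages tagged with a transfer state. A minimal standalone sketch of that chunk/reassemble pattern (hypothetical `Chunk` type and chunk size, not the actual LeRobot helper):

```python
from dataclasses import dataclass
from typing import Iterator

# Mirror the proto's TransferState values
TRANSFER_BEGIN, TRANSFER_MIDDLE, TRANSFER_END = 1, 2, 3
CHUNK_SIZE = 2 * 1024 * 1024  # stay well under gRPC's default 4 MB message limit


@dataclass
class Chunk:
    transfer_state: int
    data: bytes


def chunk_bytes(payload: bytes, chunk_size: int = CHUNK_SIZE) -> Iterator[Chunk]:
    """Yield `payload` as an ordered stream of chunks."""
    views = [payload[i : i + chunk_size] for i in range(0, len(payload), chunk_size)] or [b""]
    last = len(views) - 1
    for i, view in enumerate(views):
        if i == last:
            state = TRANSFER_END  # a single-chunk payload is sent directly as END
        elif i == 0:
            state = TRANSFER_BEGIN
        else:
            state = TRANSFER_MIDDLE
        yield Chunk(state, view)


def reassemble(chunks) -> bytes:
    """Rebuild the original payload on the receiving side."""
    buf = bytearray()
    for chunk in chunks:
        if chunk.transfer_state == TRANSFER_BEGIN:
            buf.clear()  # a new message starts
        buf.extend(chunk.data)
        if chunk.transfer_state == TRANSFER_END:
            return bytes(buf)
    return bytes(buf)
```

The real helper additionally wraps each chunk in the generated protobuf message type; the splitting and state-tagging logic is the same idea.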
    def _aggregate_action_queues(
        self,
        incoming_actions: list[TimedAction],
        aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """Find actions with the same timestep in the queue and aggregate them using aggregate_fn."""
        if aggregate_fn is None:
            # Default aggregate function: take the latest action
            def aggregate_fn(x1, x2):
                return x2

        future_action_queue = Queue()
        with self.action_queue_lock:
            internal_queue = self.action_queue.queue

            current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue}

        for new_action in incoming_actions:
            with self.latest_action_lock:
                latest_action = self.latest_action

            # The new action is older than the latest executed action: skip it
            if new_action.get_timestep() <= latest_action:
                continue

            # The new action's timestep is not in the current queue: add it directly
            elif new_action.get_timestep() not in current_action_queue:
                future_action_queue.put(new_action)
                continue

            # The new action's timestep is in the current queue: aggregate the two actions
            # TODO: There is probably a way to do this with broadcasting of the two action tensors
            future_action_queue.put(
                TimedAction(
                    timestamp=new_action.get_timestamp(),
                    timestep=new_action.get_timestep(),
                    action=aggregate_fn(
                        current_action_queue[new_action.get_timestep()], new_action.get_action()
                    ),
                )
            )

        with self.action_queue_lock:
            self.action_queue = future_action_queue

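`_aggregate_action_queues` resolves conflicts between actions still queued and a freshly received chunk covering the same timesteps; by default the newer prediction simply wins (`aggregate_fn(x1, x2) = x2`). A custom aggregator can instead blend the two predictions, which smooths the hand-over between chunks. A sketch (the `make_blend_aggregate` helper and its weight are illustrative, not part of LeRobot):

```python
def make_blend_aggregate(alpha: float = 0.7):
    """Build an aggregate_fn giving weight `alpha` to the newer prediction.

    Works elementwise, so it applies equally to Python floats or torch tensors.
    alpha=1.0 recovers the default "take the latest action" behaviour.
    """

    def aggregate_fn(old_action, new_action):
        # Convex combination of the overlapping predictions
        return (1.0 - alpha) * old_action + alpha * new_action

    return aggregate_fn
```

Such a function would be supplied through `RobotClientConfig.aggregate_fn`, which `receive_actions` forwards to `_aggregate_action_queues`.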
    def receive_actions(self, verbose: bool = False):
        """Receive actions from the policy server."""
        # Wait at the barrier for a synchronized start
        self.start_barrier.wait()
        self.logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Poll the server for the next chunk of actions
                actions_chunk = self.stub.GetActions(async_inference_pb2.Empty())
                if len(actions_chunk.data) == 0:
                    continue  # received `Empty` from server, wait for next call

                receive_time = time.time()

                # Deserialize bytes back into list[TimedAction]
                deserialize_start = time.perf_counter()
                timed_actions = pickle.loads(actions_chunk.data)  # nosec
                deserialize_time = time.perf_counter() - deserialize_start

                self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))

                # Log queue state and network latency for the incoming chunk
                if len(timed_actions) > 0 and verbose:
                    with self.latest_action_lock:
                        latest_action = self.latest_action

                    self.logger.debug(f"Current latest action: {latest_action}")

                    # Get queue state before changes
                    old_size, old_timesteps = self._inspect_action_queue()
                    if not old_timesteps:
                        old_timesteps = [latest_action]  # queue was empty

                    # Log incoming actions
                    incoming_timesteps = [a.get_timestep() for a in timed_actions]

                    first_action_timestep = timed_actions[0].get_timestep()
                    server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000

                    self.logger.info(
                        f"Received action chunk for step #{first_action_timestep} | "
                        f"Latest action: #{latest_action} | "
                        f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Network latency (server->client): {server_to_client_latency:.2f}ms | "
                        f"Deserialization time: {deserialize_time * 1000:.2f}ms"
                    )

                # Update the action queue
                start_time = time.perf_counter()
                self._aggregate_action_queues(timed_actions, self.config.aggregate_fn)
                queue_update_time = time.perf_counter() - start_time

                self.must_go.set()  # after receiving actions, the next empty queue triggers must-go processing!

                if verbose:
                    # Get queue state after changes
                    new_size, new_timesteps = self._inspect_action_queue()

                    with self.latest_action_lock:
                        latest_action = self.latest_action

                    self.logger.info(
                        f"Latest action: {latest_action} | "
                        f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | "
                        f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | "
                        f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}"
                    )
                    self.logger.debug(
                        f"Queue update complete ({queue_update_time:.6f}s) | "
                        f"Before: {old_size} items | "
                        f"After: {new_size} items"
                    )

            except grpc.RpcError as e:
                self.logger.error(f"Error receiving actions: {e}")

    def actions_available(self):
        """Check if there are actions available in the queue."""
        with self.action_queue_lock:
            return not self.action_queue.empty()

    def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]:
        action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
        return action

    def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
        """Read the next action from the local queue and perform it on the robot."""

        # Lock only for queue operations
        get_start = time.perf_counter()
        with self.action_queue_lock:
            self.action_queue_size.append(self.action_queue.qsize())
            # Get action from queue
            timed_action = self.action_queue.get_nowait()
        get_end = time.perf_counter() - get_start

        _performed_action = self.robot.send_action(
            self._action_tensor_to_action_dict(timed_action.get_action())
        )
        with self.latest_action_lock:
            self.latest_action = timed_action.get_timestep()

        if verbose:
            with self.action_queue_lock:
                current_queue_size = self.action_queue.qsize()

            self.logger.debug(
                f"Ts={timed_action.get_timestamp()} | "
                f"Action #{timed_action.get_timestep()} performed | "
                f"Queue size: {current_queue_size}"
            )

            self.logger.debug(
                f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}"
            )

        return _performed_action

    def _ready_to_send_observation(self):
        """Flag when the client is ready to send an observation."""
        with self.action_queue_lock:
            return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold

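`_ready_to_send_observation` is the heart of the async trade-off: `chunk_size_threshold` is the fraction of the current action chunk that may remain queued before a new observation is sent off for inference. The condition in isolation, with illustrative numbers (not LeRobot defaults):

```python
def ready_to_send(queue_size: int, chunk_size: int, threshold: float) -> bool:
    """True once the queued fraction of the chunk has dropped to the threshold."""
    return queue_size / chunk_size <= threshold


# With a 50-action chunk and threshold=0.5, the client requests new actions
# once 25 or fewer actions remain in the queue; threshold=0.0 waits for an
# empty queue (sequential-like), while threshold=1.0 sends an observation
# on every control cycle.
```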
    def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation:
        try:
            # Capture an observation from the robot
            start_time = time.perf_counter()

            raw_observation: RawObservation = self.robot.get_observation()
            raw_observation["task"] = task

            with self.latest_action_lock:
                latest_action = self.latest_action

            observation = TimedObservation(
                timestamp=time.time(),  # need time.time() to compare timestamps across client and server
                observation=raw_observation,
                timestep=max(latest_action, 0),
            )

            obs_capture_time = time.perf_counter() - start_time

            # If there are no actions left in the queue, the observation must go through processing!
            with self.action_queue_lock:
                observation.must_go = self.must_go.is_set() and self.action_queue.empty()
                current_queue_size = self.action_queue.qsize()

            _ = self.send_observation(observation)

            self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})")
            if observation.must_go:
                # the must-go event will be set again after receiving actions
                self.must_go.clear()

            if verbose:
                # Calculate comprehensive FPS metrics
                fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp())

                self.logger.info(
                    f"Obs #{observation.get_timestep()} | "
                    f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "
                    f"Target: {fps_metrics['target_fps']:.2f}"
                )

                self.logger.debug(
                    f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s"
                )

            return raw_observation

        except Exception as e:
            self.logger.error(f"Error in observation sender: {e}")

    def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]:
        """Combined loop: execute queued actions and stream observations."""
        # Wait at the barrier for a synchronized start
        self.start_barrier.wait()
        self.logger.info("Control loop thread starting")

        _performed_action = None
        _captured_observation = None

        while self.running:
            control_loop_start = time.perf_counter()
            # Control loop: (1) perform actions, when available
            if self.actions_available():
                _performed_action = self.control_loop_action(verbose)

            # Control loop: (2) stream observations to the remote policy server
            if self._ready_to_send_observation():
                _captured_observation = self.control_loop_observation(task, verbose)

            self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
            # Dynamically adjust the sleep time to maintain the desired control frequency
            time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))

        return _captured_observation, _performed_action
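The `time.sleep(max(0, ...))` line above is what pins the control loop to the configured frequency: each iteration sleeps only for whatever remains of the environment period after acting and sending observations. The same pacing pattern in isolation (hypothetical helper; the period is illustrative):

```python
import time


def paced_loop(step_fn, environment_dt: float, n_steps: int) -> list[float]:
    """Run step_fn at a fixed period, returning the sleep applied per iteration."""
    sleeps = []
    for _ in range(n_steps):
        start = time.perf_counter()
        step_fn()
        # Sleep only for the remainder of the period; never a negative duration
        sleep_s = max(0.0, environment_dt - (time.perf_counter() - start))
        sleeps.append(sleep_s)
        time.sleep(sleep_s)
    return sleeps
```

If a step overruns the period, the loop simply does not sleep; it never tries to "catch up", which keeps latency bounded at the cost of a temporarily lower effective FPS.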

@draccus.wrap()
def async_client(cfg: RobotClientConfig):
    logging.info(pformat(asdict(cfg)))

    if cfg.robot.type not in SUPPORTED_ROBOTS:
        raise ValueError(f"Robot {cfg.robot.type} not yet supported!")

    client = RobotClient(cfg)

    if client.start():
        client.logger.info("Starting action receiver thread...")

        # Create and start the action receiver thread
        action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
        action_receiver_thread.start()

        try:
            # The main thread runs the control loop
            client.control_loop(task=cfg.task)

        finally:
            client.stop()
            action_receiver_thread.join()
            if cfg.debug_visualize_queue_size:
                visualize_action_queue_size(client.action_queue_size)
            client.logger.info("Client stopped")


if __name__ == "__main__":
    async_client()  # run the client

src/lerobot/transport/async_inference.proto (new file)
@@ -0,0 +1,59 @@
// fmt: off
// flake8: noqa

// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

//     http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package async_inference;

// AsyncInference, from the robot's perspective:
// the robot sends observations to, and executes actions received from, a remote policy server
service AsyncInference {
  // Robot -> Policy: share observations with a remote inference server
  // Policy -> Robot: share actions predicted for given observations
  rpc SendObservations(stream Observation) returns (Empty);
  rpc GetActions(Empty) returns (Actions);
  rpc SendPolicyInstructions(PolicySetup) returns (Empty);
  rpc Ready(Empty) returns (Empty);
  rpc Stop(Empty) returns (Empty);
}

enum TransferState {
  TRANSFER_UNKNOWN = 0;
  TRANSFER_BEGIN = 1;
  TRANSFER_MIDDLE = 2;
  TRANSFER_END = 3;
}

// Messages
message Observation {
  // sent by the robot to the remote policy
  TransferState transfer_state = 1;  // observations exceeding gRPC's 4 MB message limit are streamed in chunks
  bytes data = 2;
}

message Actions {
  // sent by the remote policy to the robot
  bytes data = 1;
}

message PolicySetup {
  // sent by the robot to the remote server, to initialize the policy
  bytes data = 1;
}

message Empty {}

src/lerobot/transport/async_inference_pb2.py (new file)
@@ -0,0 +1,45 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: async_inference.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    29,
    0,
    '',
    'async_inference.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_TRANSFERSTATE']._serialized_start=190
  _globals['_TRANSFERSTATE']._serialized_end=286
  _globals['_OBSERVATION']._serialized_start=42
  _globals['_OBSERVATION']._serialized_end=125
  _globals['_ACTIONS']._serialized_start=127
  _globals['_ACTIONS']._serialized_end=150
  _globals['_POLICYSETUP']._serialized_start=152
  _globals['_POLICYSETUP']._serialized_end=179
  _globals['_EMPTY']._serialized_start=181
  _globals['_EMPTY']._serialized_end=188
  _globals['_ASYNCINFERENCE']._serialized_start=289
  _globals['_ASYNCINFERENCE']._serialized_end=638
# @@protoc_insertion_point(module_scope)

src/lerobot/transport/async_inference_pb2_grpc.py (new file)
@@ -0,0 +1,277 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from lerobot.transport import async_inference_pb2 as async__inference__pb2

GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in async_inference_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class AsyncInferenceStub:
    """AsyncInference, from the robot's perspective:
    the robot sends observations to, and executes actions received from, a remote policy server
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.SendObservations = channel.stream_unary(
                '/async_inference.AsyncInference/SendObservations',
                request_serializer=async__inference__pb2.Observation.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.GetActions = channel.unary_unary(
                '/async_inference.AsyncInference/GetActions',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Actions.FromString,
                _registered_method=True)
        self.SendPolicyInstructions = channel.unary_unary(
                '/async_inference.AsyncInference/SendPolicyInstructions',
                request_serializer=async__inference__pb2.PolicySetup.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.Ready = channel.unary_unary(
                '/async_inference.AsyncInference/Ready',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)
        self.Stop = channel.unary_unary(
                '/async_inference.AsyncInference/Stop',
                request_serializer=async__inference__pb2.Empty.SerializeToString,
                response_deserializer=async__inference__pb2.Empty.FromString,
                _registered_method=True)


class AsyncInferenceServicer:
    """AsyncInference, from the robot's perspective:
    the robot sends observations to, and executes actions received from, a remote policy server
    """

    def SendObservations(self, request_iterator, context):
        """Robot -> Policy: share observations with a remote inference server
        Policy -> Robot: share actions predicted for given observations
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetActions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendPolicyInstructions(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Ready(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Stop(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_AsyncInferenceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'SendObservations': grpc.stream_unary_rpc_method_handler(
                    servicer.SendObservations,
                    request_deserializer=async__inference__pb2.Observation.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'GetActions': grpc.unary_unary_rpc_method_handler(
                    servicer.GetActions,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Actions.SerializeToString,
            ),
            'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler(
                    servicer.SendPolicyInstructions,
                    request_deserializer=async__inference__pb2.PolicySetup.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Ready': grpc.unary_unary_rpc_method_handler(
                    servicer.Ready,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
            'Stop': grpc.unary_unary_rpc_method_handler(
                    servicer.Stop,
                    request_deserializer=async__inference__pb2.Empty.FromString,
                    response_serializer=async__inference__pb2.Empty.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'async_inference.AsyncInference', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class AsyncInference:
    """AsyncInference, from the robot's perspective:
    the robot sends observations to, and executes actions received from, a remote policy server
    """

    @staticmethod
    def SendObservations(request_iterator,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.stream_unary(
            request_iterator,
            target,
            '/async_inference.AsyncInference/SendObservations',
            async__inference__pb2.Observation.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetActions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/GetActions',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Actions.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def SendPolicyInstructions(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/SendPolicyInstructions',
            async__inference__pb2.PolicySetup.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Ready(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Ready',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Stop(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/async_inference.AsyncInference/Stop',
            async__inference__pb2.Empty.SerializeToString,
            async__inference__pb2.Empty.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

tests/async_inference/test_e2e.py (new file)
@@ -0,0 +1,177 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end test of the asynchronous inference stack (client ↔ server).

This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
is to exercise the full communication loop:

1. Client sends policy specification → Server
2. Client streams observations → Server
3. Server streams action chunks → Client
4. Client executes received actions

The test succeeds if at least one action is executed and the server records at
least one predicted timestep, demonstrating that the gRPC round-trip works
end-to-end using real (but lightweight) protocol messages.
"""

from __future__ import annotations

import threading
from concurrent import futures

import pytest
import torch

# Skip entire module if grpc is not available
pytest.importorskip("grpc")

# -----------------------------------------------------------------------------
# End-to-end test
# -----------------------------------------------------------------------------


def test_async_inference_e2e(monkeypatch):
    """Tests the full asynchronous inference pipeline."""
    # Import grpc-dependent modules inside the test function
    import grpc

    from lerobot.robots.utils import make_robot_from_config
    from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
    from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
    from lerobot.scripts.server.policy_server import PolicyServer
    from lerobot.scripts.server.robot_client import RobotClient
    from lerobot.transport import (
        async_inference_pb2,  # type: ignore
        async_inference_pb2_grpc,  # type: ignore
    )
    from tests.mocks.mock_robot import MockRobotConfig

    # Create a stub policy similar to test_policy_server.py
    class MockPolicy:
        """A minimal mock for an actual policy, returning zeros."""

        class _Config:
            robot_type = "dummy_robot"

        @property
        def image_features(self):
            """Empty image features since this test doesn't use images."""
            return {}

        def __init__(self):
            self.config = self._Config()

        def to(self, *args, **kwargs):
            return self

        def model(self, batch):
            # Return a chunk of 20 dummy actions.
            batch_size = len(batch["robot_type"])
            return torch.zeros(batch_size, 20, 6)

    # ------------------------------------------------------------------
    # 1. Create PolicyServer instance with mock policy
    # ------------------------------------------------------------------
    policy_server_config = PolicyServerConfig(host="localhost", port=9999)
    policy_server = PolicyServer(policy_server_config)
    # Replace the real policy with our fast, deterministic stub.
    policy_server.policy = MockPolicy()
    policy_server.actions_per_chunk = 20
    policy_server.device = "cpu"

    # Set up robot config and features
    robot_config = MockRobotConfig()
    mock_robot = make_robot_from_config(robot_config)

    lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
    policy_server.lerobot_features = lerobot_features

    # Force server to produce deterministic action chunks in test mode
    policy_server.policy_type = "act"

    def _fake_get_action_chunk(_self, _obs, _type="test"):
        action_dim = 6
        batch_size = 1
        actions_per_chunk = policy_server.actions_per_chunk

        return torch.zeros(batch_size, actions_per_chunk, action_dim)

    monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)

    # Bypass potentially heavy model loading inside SendPolicyInstructions
    def _fake_send_policy_instructions(self, request, context):  # noqa: N802
        return async_inference_pb2.Empty()

    monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)

    # Build gRPC server running a PolicyServer
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)

    # Use the host/port specified in the fixture's config
    server_address = f"{policy_server.config.host}:{policy_server.config.port}"
    server.add_insecure_port(server_address)
    server.start()

    # ------------------------------------------------------------------
    # 2. Create a RobotClient around the MockRobot
    # ------------------------------------------------------------------
    client_config = RobotClientConfig(
        server_address=server_address,
        robot=robot_config,
        chunk_size_threshold=0.0,
        policy_type="test",
        pretrained_name_or_path="test",
        actions_per_chunk=20,
        verify_robot_cameras=False,
    )

    client = RobotClient(client_config)
    assert client.start(), "Client failed initial handshake with the server"

    # Track action chunks received without modifying RobotClient
    action_chunks_received = {"count": 0}
    original_aggregate = client._aggregate_action_queues

    def counting_aggregate(*args, **kwargs):
        action_chunks_received["count"] += 1
        return original_aggregate(*args, **kwargs)

    monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)

    # Start client threads
    action_thread = threading.Thread(target=client.receive_actions, daemon=True)
    # NOTE: the trailing comma is required so the dict is passed as a single
    # positional argument rather than being iterated key-by-key by Thread.
    control_thread = threading.Thread(target=client.control_loop, args=({"task": ""},), daemon=True)
    action_thread.start()
    control_thread.start()

    # ------------------------------------------------------------------
    # 3. System exchanges a few messages
    # ------------------------------------------------------------------
    # Wait for 5 seconds
    server.wait_for_termination(timeout=5)

    assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
    assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"

    # ------------------------------------------------------------------
    # 4. Stop the system
    # ------------------------------------------------------------------
    client.stop()
    action_thread.join()
    control_thread.join()
    policy_server.stop()
    server.stop(grace=None)
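The end-to-end test above exercises a two-thread client pattern: one thread receives action chunks from the server while the control loop consumes them at its own pace. A minimal stdlib-only sketch of that producer/consumer split (all names here are illustrative stand-ins, not the LeRobot API):

```python
import queue
import threading

action_queue: "queue.Queue[list[float]]" = queue.Queue()
executed = []


def receive_actions(n_chunks: int, chunk_len: int) -> None:
    # Producer: stands in for RobotClient.receive_actions streaming chunks
    # from the policy server into a local queue.
    for chunk_id in range(n_chunks):
        chunk = [[float(chunk_id)] * 6 for _ in range(chunk_len)]
        for action in chunk:
            action_queue.put(action)


def control_loop(total_actions: int) -> None:
    # Consumer: stands in for the robot executing actions as they arrive,
    # never blocking on inference itself.
    for _ in range(total_actions):
        executed.append(action_queue.get(timeout=1.0))


producer = threading.Thread(target=receive_actions, args=(3, 20), daemon=True)
consumer = threading.Thread(target=control_loop, args=(60,), daemon=True)
producer.start()
consumer.start()
producer.join()
consumer.join()
print(len(executed))  # 60 actions executed across 3 chunks
```

The queue decouples action arrival from action execution, which is the core idea the real `receive_actions` / `control_loop` threads implement over gRPC.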
459
tests/async_inference/test_helpers.py
Normal file
@@ -0,0 +1,459 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import pickle
import time

import numpy as np
import torch

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.scripts.server.helpers import (
    FPSTracker,
    TimedAction,
    TimedObservation,
    observations_similar,
    prepare_image,
    prepare_raw_observation,
    raw_observation_to_observation,
    resize_robot_observation_image,
)

# ---------------------------------------------------------------------
# FPSTracker
# ---------------------------------------------------------------------


def test_fps_tracker_first_observation():
    """First observation should initialize timestamp and return 0 FPS."""
    tracker = FPSTracker(target_fps=30.0)
    timestamp = 1000.0

    metrics = tracker.calculate_fps_metrics(timestamp)

    assert tracker.first_timestamp == timestamp
    assert tracker.total_obs_count == 1
    assert metrics["avg_fps"] == 0.0
    assert metrics["target_fps"] == 30.0


def test_fps_tracker_single_interval():
    """Two observations 1 second apart should give 1 FPS."""
    tracker = FPSTracker(target_fps=30.0)

    # First observation at t=0
    metrics1 = tracker.calculate_fps_metrics(0.0)
    assert metrics1["avg_fps"] == 0.0

    # Second observation at t=1 (1 second later)
    metrics2 = tracker.calculate_fps_metrics(1.0)
    expected_fps = 1.0  # (2-1) observations / 1.0 seconds = 1 FPS
    assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)


def test_fps_tracker_multiple_intervals():
    """Multiple observations should calculate correct average FPS."""
    tracker = FPSTracker(target_fps=30.0)

    # Simulate 5 observations over 2 seconds (should be 2 FPS average)
    timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]

    for i, ts in enumerate(timestamps):
        metrics = tracker.calculate_fps_metrics(ts)

        if i == 0:
            assert metrics["avg_fps"] == 0.0
        elif i == len(timestamps) - 1:
            # After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
            expected_fps = 2.0
            assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)


def test_fps_tracker_irregular_intervals():
    """FPS calculation should work with irregular time intervals."""
    tracker = FPSTracker(target_fps=30.0)

    # Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
    timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]

    for ts in timestamps:
        metrics = tracker.calculate_fps_metrics(ts)

    # 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
    expected_fps = 4.0 / 3.0
    assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)


# ---------------------------------------------------------------------
# TimedData helpers
# ---------------------------------------------------------------------


def test_timed_action_getters():
    """TimedAction stores & returns timestamp, action tensor and timestep."""
    ts = time.time()
    action = torch.arange(10)
    ta = TimedAction(timestamp=ts, action=action, timestep=0)

    assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
    torch.testing.assert_close(ta.get_action(), action)
    assert ta.get_timestep() == 0


def test_timed_observation_getters():
    """TimedObservation stores & returns timestamp, dict and timestep."""
    ts = time.time()
    obs_dict = {"observation.state": torch.ones(6)}
    to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)

    assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
    assert to.get_observation() is obs_dict
    assert to.get_timestep() == 0


def test_timed_data_deserialization_data_getters():
    """TimedAction / TimedObservation survive a round-trip through ``pickle``.

    The async-inference stack uses ``pickle.dumps`` to move these objects across
    the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
    This test ensures that the payload keeps its content intact after
    the (de)serialization round-trip.
    """
    ts = time.time()

    # ------------------------------------------------------------------
    # TimedAction
    # ------------------------------------------------------------------
    original_action = torch.randn(6)
    ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)

    # Serialize → bytes → deserialize
    ta_bytes = pickle.dumps(ta_in)  # nosec
    ta_out: TimedAction = pickle.loads(ta_bytes)  # nosec B301

    # Identity & content checks
    assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
    assert ta_out.get_timestep() == 13
    torch.testing.assert_close(ta_out.get_action(), original_action)

    # ------------------------------------------------------------------
    # TimedObservation
    # ------------------------------------------------------------------
    obs_dict = {"observation.state": torch.arange(4).float()}
    to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)

    to_bytes = pickle.dumps(to_in)  # nosec
    to_out: TimedObservation = pickle.loads(to_bytes)  # nosec B301

    assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
    assert to_out.get_timestep() == 7
    assert to_out.must_go is True
    assert to_out.get_observation().keys() == obs_dict.keys()
    torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])


# ---------------------------------------------------------------------
# observations_similar()
# ---------------------------------------------------------------------


def _make_obs(state: torch.Tensor) -> TimedObservation:
    """Create a TimedObservation with raw robot observation format."""
    return TimedObservation(
        timestamp=time.time(),
        observation={
            "shoulder": state[0].item() if len(state) > 0 else 0.0,
            "elbow": state[1].item() if len(state) > 1 else 0.0,
            "wrist": state[2].item() if len(state) > 2 else 0.0,
            "gripper": state[3].item() if len(state) > 3 else 0.0,
        },
        timestep=0,
    )


def test_observations_similar_true():
    """Distance below atol → observations considered similar; above → not."""
    # Create mock lerobot features for the similarity check
    lerobot_features = {
        "observation.state": {
            "dtype": "float32",
            "shape": [4],
            "names": ["shoulder", "elbow", "wrist", "gripper"],
        }
    }

    obs1 = _make_obs(torch.zeros(4))
    obs2 = _make_obs(0.5 * torch.ones(4))
    assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)

    obs3 = _make_obs(2.0 * torch.ones(4))
    assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)


# ---------------------------------------------------------------------
# raw_observation_to_observation and helpers
# ---------------------------------------------------------------------


def _create_mock_robot_observation():
    """Create a mock robot observation with motor positions and camera images."""
    return {
        "shoulder": 1.0,
        "elbow": 2.0,
        "wrist": 3.0,
        "gripper": 0.5,
        "laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
        "phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
    }


def _create_mock_lerobot_features():
    """Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
    return {
        "observation.state": {
            "dtype": "float32",
            "shape": [4],
            "names": ["shoulder", "elbow", "wrist", "gripper"],
        },
        "observation.images.laptop": {
            "dtype": "image",
            "shape": [480, 640, 3],
            "names": ["height", "width", "channels"],
        },
        "observation.images.phone": {
            "dtype": "image",
            "shape": [480, 640, 3],
            "names": ["height", "width", "channels"],
        },
    }


def _create_mock_policy_image_features():
    """Create mock policy image features with different resolutions."""
    return {
        "observation.images.laptop": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 224, 224),  # Policy expects smaller resolution
        ),
        "observation.images.phone": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 160, 160),  # Different resolution for second camera
        ),
    }


def test_prepare_image():
    """Test image preprocessing: uint8 → float32, normalization to [0, 1]."""
    # Create mock uint8 image data
    image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)

    processed = prepare_image(image_int8)

    # Check dtype conversion
    assert processed.dtype == torch.float32

    # Check normalization range
    assert processed.min() >= 0.0
    assert processed.max() <= 1.0

    # Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
    if image_int8.max() == 255:
        assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
    if image_int8.min() == 0:
        assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)

    # Check memory contiguity
    assert processed.is_contiguous()


def test_resize_robot_observation_image():
    """Test image resizing from robot resolution to policy resolution."""
    # Create mock image: (H=480, W=640, C=3)
    original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
    target_shape = (3, 224, 224)  # (C, H, W)

    resized = resize_robot_observation_image(original_image, target_shape)

    # Check output shape matches target
    assert resized.shape == target_shape

    # Check that original image had different dimensions
    assert original_image.shape != resized.shape

    # Check that resizing preserves value range
    assert resized.min() >= 0
    assert resized.max() <= 255


def test_prepare_raw_observation():
    """Test the preparation of raw robot observation to lerobot format."""
    robot_obs = _create_mock_robot_observation()
    lerobot_features = _create_mock_lerobot_features()
    policy_image_features = _create_mock_policy_image_features()

    prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)

    # Check that state is properly extracted and batched
    assert "observation.state" in prepared
    state = prepared["observation.state"]
    assert isinstance(state, torch.Tensor)
    assert state.shape == (1, 4)  # Batched state

    # Check that images are processed and resized
    assert "observation.images.laptop" in prepared
    assert "observation.images.phone" in prepared

    laptop_img = prepared["observation.images.laptop"]
    phone_img = prepared["observation.images.phone"]

    # Check image shapes match policy requirements
    assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
    assert phone_img.shape == policy_image_features["observation.images.phone"].shape

    # Check that images are tensors
    assert isinstance(laptop_img, torch.Tensor)
    assert isinstance(phone_img, torch.Tensor)


def test_raw_observation_to_observation_basic():
    """Test the main raw_observation_to_observation function."""
    robot_obs = _create_mock_robot_observation()
    lerobot_features = _create_mock_lerobot_features()
    policy_image_features = _create_mock_policy_image_features()
    device = "cpu"

    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)

    # Check that all expected keys are present
    assert "observation.state" in observation
    assert "observation.images.laptop" in observation
    assert "observation.images.phone" in observation

    # Check state processing
    state = observation["observation.state"]
    assert isinstance(state, torch.Tensor)
    assert state.device.type == device
    assert state.shape == (1, 4)  # Batched

    # Check image processing
    laptop_img = observation["observation.images.laptop"]
    phone_img = observation["observation.images.phone"]

    # Images should have batch dimension: (B, C, H, W)
    assert laptop_img.shape == (1, 3, 224, 224)
    assert phone_img.shape == (1, 3, 160, 160)

    # Check device placement
    assert laptop_img.device.type == device
    assert phone_img.device.type == device

    # Check image dtype and range (should be float32 in [0, 1])
    assert laptop_img.dtype == torch.float32
    assert phone_img.dtype == torch.float32
    assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
    assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0


def test_raw_observation_to_observation_with_non_tensor_data():
    """Test that non-tensor data (like task strings) is preserved."""
    robot_obs = _create_mock_robot_observation()
    robot_obs["task"] = "pick up the red cube"  # Add string instruction

    lerobot_features = _create_mock_lerobot_features()
    policy_image_features = _create_mock_policy_image_features()
    device = "cpu"

    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)

    # Check that task string is preserved
    assert "task" in observation
    assert observation["task"] == "pick up the red cube"
    assert isinstance(observation["task"], str)


@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
    """Test that tensors are properly moved to the specified device."""
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    robot_obs = _create_mock_robot_observation()
    lerobot_features = _create_mock_lerobot_features()
    policy_image_features = _create_mock_policy_image_features()

    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)

    # Check that all tensors are on the correct device
    for key, value in observation.items():
        if isinstance(value, torch.Tensor):
            assert value.device.type == device, f"Tensor {key} not on {device}"


def test_raw_observation_to_observation_deterministic():
    """Test that the function produces consistent results for the same input."""
    robot_obs = _create_mock_robot_observation()
    lerobot_features = _create_mock_lerobot_features()
    policy_image_features = _create_mock_policy_image_features()
    device = "cpu"

    # Run twice with same input
    obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
    obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)

    # Results should be identical
    assert set(obs1.keys()) == set(obs2.keys())

    for key in obs1:
        if isinstance(obs1[key], torch.Tensor):
            torch.testing.assert_close(obs1[key], obs2[key])
        else:
            assert obs1[key] == obs2[key]


def test_image_processing_pipeline_preserves_content():
    """Test that the image processing pipeline preserves recognizable patterns."""
    # Create an image with a specific pattern
    original_img = np.zeros((100, 100, 3), dtype=np.uint8)
    original_img[25:75, 25:75, :] = 255  # White square in center

    robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
    lerobot_features = {
        "observation.state": {
            "dtype": "float32",
            "shape": [4],
            "names": ["shoulder", "elbow", "wrist", "gripper"],
        },
        "observation.images.laptop": {
            "dtype": "image",
            "shape": [100, 100, 3],
            "names": ["height", "width", "channels"],
        },
    }
    policy_image_features = {
        "observation.images.laptop": PolicyFeature(
            type=FeatureType.VISUAL,
            shape=(3, 50, 50),  # Downsamples from 100x100
        )
    }

    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")

    processed_img = observation["observation.images.laptop"].squeeze(0)  # Remove batch dim

    # Check that the center region has higher values than corners
    # Due to bilinear interpolation, exact values will change but pattern should remain
    center_val = processed_img[:, 25, 25].mean()  # Center of 50x50 image
    corner_val = processed_img[:, 5, 5].mean()  # Corner
    assert center_val > corner_val, "Image processing should preserve recognizable patterns"
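The `prepare_image` tests above pin down a simple convention: uint8 pixels in [0, 255] become float32 values in [0.0, 1.0] by dividing by 255. A plain-Python stand-in for that normalization (not the LeRobot implementation, just the convention the assertions encode):

```python
def prepare_image_sketch(image_uint8: list[list[list[int]]]) -> list[list[list[float]]]:
    # uint8 pixels in [0, 255] -> floats in [0.0, 1.0], mirroring what the
    # tests above assert about `prepare_image` (255 -> 1.0, 0 -> 0.0).
    return [[[px / 255.0 for px in pixel] for pixel in row] for row in image_uint8]


img = [[[0, 128, 255]]]  # one pixel, three channels
out = prepare_image_sketch(img)
print(out)  # [[[0.0, 0.5019607843137255, 1.0]]]
```

Keeping the scale fixed to 1/255 (rather than min-max per image) is what makes the endpoint checks in `test_prepare_image` valid.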
215
tests/async_inference/test_policy_server.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `PolicyServer` core logic.
|
||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros.
|
||||
Refer to tests/policies for tests of the individual policies supported."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation["observation.state"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# The server calls `policy.to(device)`. This stub ignores it.
|
||||
return self
|
||||
|
||||
def model(self, batch: dict) -> torch.Tensor:
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@require_package("grpc")
|
||||
def policy_server():
|
||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.policy_server import PolicyServer
|
||||
|
||||
test_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
server = PolicyServer(test_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
server.policy = MockPolicy()
|
||||
server.actions_per_chunk = 20
|
||||
server.device = "cpu"
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
}
|
||||
}
|
||||
|
||||
return server
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
|
||||
"""Create a TimedObservation with a given state vector."""
|
||||
# Import only when needed
|
||||
from lerobot.scripts.server.helpers import TimedObservation
|
||||
|
||||
return TimedObservation(
|
||||
observation={
|
||||
"joint1": state[0].item() if len(state) > 0 else 0.0,
|
||||
"joint2": state[1].item() if len(state) > 1 else 0.0,
|
||||
"joint3": state[2].item() if len(state) > 2 else 0.0,
|
||||
"joint4": state[3].item() if len(state) > 3 else 0.0,
|
||||
"joint5": state[4].item() if len(state) > 4 else 0.0,
|
||||
"joint6": state[5].item() if len(state) > 5 else 0.0,
|
||||
},
|
||||
timestamp=time.time(),
|
||||
timestep=timestep,
|
||||
must_go=must_go,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_action_chunk(policy_server):
|
||||
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
|
||||
start_ts = time.time()
|
||||
start_t = 10
|
||||
# A chunk of 3 action tensors.
|
||||
action_tensors = [torch.randn(6) for _ in range(3)]
|
||||
|
||||
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
|
||||
|
||||
assert len(timed_actions) == 3
|
||||
# Check timesteps
|
||||
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
|
||||
# Check timestamps
|
||||
expected_timestamps = [
|
||||
start_ts,
|
||||
start_ts + policy_server.config.environment_dt,
|
||||
start_ts + 2 * policy_server.config.environment_dt,
|
||||
]
|
||||
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_must_go(policy_server):
    """An observation with `must_go=True` is always enqueued."""
    obs = _make_obs(torch.zeros(6), must_go=True)
    assert policy_server._enqueue_observation(obs) is True
    assert policy_server.observation_queue.qsize() == 1
    assert policy_server.observation_queue.get_nowait() is obs


def test_maybe_enqueue_observation_dissimilar(policy_server):
    """A dissimilar observation (not `must_go`) is enqueued."""
    # Set a last predicted observation.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))
    # Create a new, dissimilar observation.
    new_obs = _make_obs(torch.ones(6) * 5)  # High norm difference

    assert policy_server._enqueue_observation(new_obs) is True
    assert policy_server.observation_queue.qsize() == 1


def test_maybe_enqueue_observation_is_skipped(policy_server):
    """A similar observation (not `must_go`) is skipped."""
    # Set a last predicted observation.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))
    # Create a new, very similar observation.
    new_obs = _make_obs(torch.zeros(6) + 1e-4)

    assert policy_server._enqueue_observation(new_obs) is False
    assert policy_server.observation_queue.empty() is True


def test_obs_sanity_checks(policy_server):
    """Unit-test the private `_obs_sanity_checks` helper."""
    prev = _make_obs(torch.zeros(6), timestep=0)

    # Case 1 – timestep already predicted
    policy_server._predicted_timesteps.add(1)
    obs_same_ts = _make_obs(torch.ones(6), timestep=1)
    assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False

    # Case 2 – observation too similar
    policy_server._predicted_timesteps.clear()
    obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
    assert policy_server._obs_sanity_checks(obs_similar, prev) is False

    # Case 3 – genuinely new & dissimilar observation passes
    obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
    assert policy_server._obs_sanity_checks(obs_ok, prev) is True


def test_predict_action_chunk(monkeypatch, policy_server):
    """End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
    # Import only when needed
    from lerobot.scripts.server.policy_server import PolicyServer

    # Force server to act-style policy; patch method to return deterministic tensor
    policy_server.policy_type = "act"
    action_dim = 6
    batch_size = 1
    actions_per_chunk = policy_server.actions_per_chunk

    def _fake_get_action_chunk(_self, _obs, _type="act"):
        return torch.zeros(batch_size, actions_per_chunk, action_dim)

    monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)

    obs = _make_obs(torch.zeros(6), timestep=5)
    timed_actions = policy_server._predict_action_chunk(obs)

    assert len(timed_actions) == actions_per_chunk
    assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))

    for i, ta in enumerate(timed_actions):
        expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
        assert abs(ta.get_timestamp() - expected_ts) < 1e-6
234
tests/async_inference/test_robot_client.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
||||
|
||||
We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that
|
||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if grpc is not available
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------


@pytest.fixture()
def robot_client():
    """Fresh `RobotClient` instance for each test case (no threads started).

    Uses DummyRobot."""
    # Import only when the test actually runs (after decorator check)
    from lerobot.scripts.server.configs import RobotClientConfig
    from lerobot.scripts.server.robot_client import RobotClient
    from tests.mocks.mock_robot import MockRobotConfig

    test_config = MockRobotConfig()

    # gRPC channel is not actually used in tests, so using a dummy address
    test_config = RobotClientConfig(
        robot=test_config,
        server_address="localhost:9999",
        policy_type="test",
        pretrained_name_or_path="test",
        actions_per_chunk=20,
        verify_robot_cameras=False,
    )

    client = RobotClient(test_config)

    # Initialize attributes that are normally set in start() method
    client.chunks_received = 0
    client.available_actions_size = []

    yield client

    if client.robot.is_connected:
        client.stop()


# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------


def _make_actions(start_ts: float, start_t: int, count: int):
    """Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
    from lerobot.scripts.server.helpers import TimedAction

    fps = 30  # emulates most common frame-rate
    actions = []
    for i in range(count):
        timestep = start_t + i
        timestamp = start_ts + i * (1 / fps)
        action_tensor = torch.full((6,), timestep, dtype=torch.float32)
        actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
    return actions


# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------


def test_update_action_queue_discards_stale(robot_client):
    """`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""

    # Pretend we already executed up to action #4
    robot_client.latest_action = 4

    # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
    incoming = _make_actions(start_ts=time.time(), start_t=3, count=5)  # 3,4,5,6,7

    robot_client._aggregate_action_queues(incoming)

    # Extract timesteps from queue
    resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]

    assert resulting_timesteps == [5, 6, 7]


@pytest.mark.parametrize(
    "weight_old, weight_new",
    [
        (1.0, 0.0),
        (0.0, 1.0),
        (0.5, 0.5),
        (0.2, 0.8),
        (0.8, 0.2),
        (0.1, 0.9),
        (0.9, 0.1),
    ],
)
def test_aggregate_action_queues_combines_actions_in_overlap(
    robot_client, weight_old: float, weight_new: float
):
    """`_aggregate_action_queues` must combine actions on overlapping timesteps according
    to the provided aggregate_fn, here tested with multiple coefficients."""
    from lerobot.scripts.server.helpers import TimedAction

    robot_client.chunks_received = 0

    # Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
    robot_client.latest_action = 4
    current_actions = _make_actions(
        start_ts=time.time(), start_t=5, count=2
    )  # actions are [torch.ones(6), torch.ones(6), ...]
    current_actions = [
        TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
        for a in current_actions
    ]

    for a in current_actions:
        robot_client.action_queue.put(a)

    # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
    incoming = _make_actions(start_ts=time.time(), start_t=3, count=5)  # 3,4,5,6,7

    overlap_timesteps = [5, 6]  # properly tested in test_aggregate_action_queues_discards_stale
    nonoverlap_timesteps = [7]

    robot_client._aggregate_action_queues(
        incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
    )

    queue_overlap_actions = []
    queue_non_overlap_actions = []
    for a in robot_client.action_queue.queue:
        if a.get_timestep() in overlap_timesteps:
            queue_overlap_actions.append(a)
        elif a.get_timestep() in nonoverlap_timesteps:
            queue_non_overlap_actions.append(a)

    queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
    queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())

    assert torch.allclose(
        queue_overlap_actions[0].get_action(),
        weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
    )
    assert torch.allclose(
        queue_overlap_actions[1].get_action(),
        weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
    )
    assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())


@pytest.mark.parametrize(
    "chunk_size, queue_len, expected",
    [
        (20, 12, False),  # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
        (20, 8, True),  # 8 / 20 = 0.4 <= g=0.5, ready to send
        (10, 5, True),
        (10, 6, False),
    ],
)
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
    """Validate `_ready_to_send_observation` ratio logic for various sizes."""

    robot_client.action_chunk_size = chunk_size

    # Clear any existing actions then fill with `queue_len` dummy entries
    robot_client.action_queue = Queue()

    dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
    for act in dummy_actions:
        robot_client.action_queue.put(act)

    assert robot_client._ready_to_send_observation() is expected


@pytest.mark.parametrize(
    "g_threshold, expected",
    [
        # The condition is `queue_size / chunk_size <= g`.
        # Here, ratio = 6 / 10 = 0.6.
        (0.0, False),  # 0.6 <= 0.0 is False
        (0.1, False),
        (0.2, False),
        (0.3, False),
        (0.4, False),
        (0.5, False),
        (0.6, True),  # 0.6 <= 0.6 is True
        (0.7, True),
        (0.8, True),
        (0.9, True),
        (1.0, True),
    ],
)
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
    """Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
    # Fixed sizes for this test: ratio = 6 / 10 = 0.6
    chunk_size = 10
    queue_len = 6

    robot_client.action_chunk_size = chunk_size
    # This is the parameter we are testing
    robot_client._chunk_size_threshold = g_threshold

    # Fill queue with dummy actions
    robot_client.action_queue = Queue()
    dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
    for act in dummy_actions:
        robot_client.action_queue.put(act)

    assert robot_client._ready_to_send_observation() is expected