fix(async): Add pre and post processing to async inference and update docs (#2132)

* Add pre and post processing to async inference and update docs

* pre-commit: fix typo

* fix tests

* refactor(async): no None branching for processors in _predict_action_chunk

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Author: Pepijn
Date: 2025-10-07 15:10:31 +02:00
Committed by: GitHub
Parent: fcaa0ea5f9
Commit: 9f32e00f90
8 changed files with 103 additions and 76 deletions


@@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
 You can spin up a policy server running:
 ```shell
-python src/lerobot/async_inference/policy_server.py \
+python -m lerobot.async_inference.policy_server \
     --host=127.0.0.1 \
-    --port=8080 \
+    --port=8080
 ```
 This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty: all information about which policy to run, and with which parameters, is specified during the first handshake with the client. Spin up a client with:
 ```shell
-python src/lerobot/async_inference/robot_client.py \
+python -m lerobot.async_inference.robot_client \
     --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
     --robot.type=so100_follower \ # ROBOT: your robot type
     --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
@@ -113,9 +113,9 @@ As such, spinning up a policy server is as easy as specifying the host address a
 <hfoptions id="start_policy_server">
 <hfoption id="Command">
 ```bash
-python -m lerobot.scripts.server.policy_server \
-    --host="localhost" \
+python -m lerobot.async_inference.policy_server \
+    --host=127.0.0.1 \
     --port=8080
 ```
 </hfoption>
 <hfoption id="API example">
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
 <hfoptions id="start_robot_client">
 <hfoption id="Command">
 ```bash
-python src/lerobot/async_inference/robot_client.py \
+python -m lerobot.async_inference.robot_client \
     --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
     --robot.type=so100_follower \ # ROBOT: your robot type
     --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port


@@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2
 SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
 # TODO: Add all other robots
-SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
+SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]
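These two lists gate the client handshake: a client requesting a policy or robot outside them should be rejected early rather than mid-episode. A minimal sketch of such a guard, using a hypothetical `validate_client_specs` helper (the server's actual validation code is not shown in this diff):

```python
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"]


def validate_client_specs(policy_type: str, robot_type: str) -> None:
    # Hypothetical guard: fail fast during the handshake instead of mid-episode
    if policy_type not in SUPPORTED_POLICIES:
        raise ValueError(f"Unsupported policy {policy_type!r}; expected one of {SUPPORTED_POLICIES}")
    if robot_type not in SUPPORTED_ROBOTS:
        raise ValueError(f"Unsupported robot {robot_type!r}; expected one of {SUPPORTED_ROBOTS}")


validate_client_specs("act", "bi_so100_follower")  # the newly added bimanual follower now passes
```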


@@ -92,11 +92,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int,
     return resized.squeeze(0)

+# TODO(Steven): Consider implementing a pipeline step for this
 def raw_observation_to_observation(
     raw_observation: RawObservation,
     lerobot_features: dict[str, dict],
     policy_image_features: dict[str, PolicyFeature],
-    device: str,
 ) -> Observation:
     observation = {}

@@ -105,9 +105,7 @@ def raw_observation_to_observation(
         if isinstance(v, torch.Tensor):  # VLAs present natural-language instructions in observations
             if "image" in k:
                 # Policy expects images in shape (B, C, H, W)
-                observation[k] = prepare_image(v).unsqueeze(0).to(device)
-            else:
-                observation[k] = v.to(device)
+                observation[k] = prepare_image(v).unsqueeze(0)
         else:
             observation[k] = v
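With the `device` parameter gone, `raw_observation_to_observation` only fixes keys, dtypes, and image shapes; device placement now happens later, inside the policy's preprocessor pipeline. A minimal sketch of the new contract (the body below is a simplified stand-in for illustration, not the library's exact implementation):

```python
import torch


def convert(raw: dict) -> dict:
    """Simplified stand-in: shape/dtype handling only, no .to(device) calls."""
    observation = {}
    for k, v in raw.items():
        if isinstance(v, torch.Tensor):
            if "image" in k:
                # HWC uint8 [0, 255] -> BCHW float32 [0, 1], as the policy expects
                observation[k] = (v.permute(2, 0, 1).float() / 255.0).unsqueeze(0)
            else:
                observation[k] = v.unsqueeze(0)  # batch the state vector
        else:
            observation[k] = v  # e.g. a natural-language task string for VLAs
    return observation


raw = {
    "observation.images.laptop": torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8),
    "observation.state": torch.zeros(4),
    "task": "pick up the cube",
}
obs = convert(raw)
assert obs["observation.state"].shape == (1, 4)  # batched, matching the tests below
# Tensors stay wherever they were created; the preprocessor's device step moves them later
assert obs["observation.images.laptop"].device.type == "cpu"
```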


@@ -15,7 +15,7 @@
 """
 Example:
 ```shell
-python src/lerobot/async_inference/policy_server.py \
+python -m lerobot.async_inference.policy_server \
     --host=127.0.0.1 \
     --port=8080 \
     --fps=30 \
@@ -32,12 +32,17 @@ from concurrent import futures
 from dataclasses import asdict
 from pprint import pformat
 from queue import Empty, Queue
+from typing import Any

 import draccus
 import grpc
 import torch

-from lerobot.policies.factory import get_policy_class
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.processor import (
+    PolicyAction,
+    PolicyProcessorPipeline,
+)
 from lerobot.transport import (
     services_pb2,  # type: ignore
     services_pb2_grpc,  # type: ignore
@@ -82,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         self.lerobot_features = None
         self.actions_per_chunk = None
         self.policy = None
+        self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None
+        self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None

     @property
     def running(self):
@@ -146,6 +153,16 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         start = time.perf_counter()
         self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
         self.policy.to(self.device)
+
+        # Load preprocessor and postprocessor, overriding device to match requested device
+        device_override = {"device": self.device}
+        self.preprocessor, self.postprocessor = make_pre_post_processors(
+            self.policy.config,
+            pretrained_path=policy_specs.pretrained_name_or_path,
+            preprocessor_overrides={"device_processor": device_override},
+            postprocessor_overrides={"device_processor": device_override},
+        )
+
         end = time.perf_counter()
         self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")
@@ -173,7 +190,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         # Calculate FPS metrics
         fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp)

-        self.logger.info(
+        self.logger.debug(
             f"Received observation #{obs_timestep} | "
             f"Avg FPS: {fps_metrics['avg_fps']:.2f} | "  # fps at which observations are received from client
             f"Target: {fps_metrics['target_fps']:.2f} | "
@@ -189,7 +206,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         if not self._enqueue_observation(
             timed_observation  # wrapping a RawObservation
         ):
-            self.logger.info(f"Observation #{obs_timestep} has been filtered out")
+            self.logger.debug(f"Observation #{obs_timestep} has been filtered out")

         return services_pb2.Empty()
@@ -301,23 +318,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
             for i, action in enumerate(action_chunk)
         ]

-    def _prepare_observation(self, observation_t: TimedObservation) -> Observation:
-        """
-        Prepare observation, ready for policy inference.
-        E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the
-        client and then convert them to float32 [0,1] images here, before running inference.
-        """
-        # RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape
-        observation: Observation = raw_observation_to_observation(
-            observation_t.get_observation(),
-            self.lerobot_features,
-            self.policy_image_features,
-            self.device,
-        )
-        # processed Observation - right keys, right dtype, right image shape
-        return observation
-
     def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
         """Get an action chunk from the policy. The chunk contains only"""
         chunk = self.policy.predict_action_chunk(observation)
@@ -327,44 +327,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
         return chunk[:, : self.actions_per_chunk, :]

     def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
-        """Predict an action chunk based on an observation"""
-        inference_starts = time.perf_counter()
+        """Predict an action chunk based on an observation.
+
+        Pipeline:
+        1. Convert raw observation to LeRobot format
+        2. Apply preprocessor (tokenization, normalization, batching, device placement)
+        3. Run policy inference to get action chunk
+        4. Apply postprocessor (unnormalization, device movement)
+        5. Convert to TimedAction list
+        """
         """1. Prepare observation"""
-        start_time = time.perf_counter()
-        observation = self._prepare_observation(observation_t)
-        preprocessing_time = time.perf_counter() - start_time
+        start_prepare = time.perf_counter()
+        observation: Observation = raw_observation_to_observation(
+            observation_t.get_observation(),
+            self.lerobot_features,
+            self.policy_image_features,
+        )
+        prepare_time = time.perf_counter() - start_prepare
+
+        """2. Apply preprocessor"""
+        start_preprocess = time.perf_counter()
+        observation = self.preprocessor(observation)

         self.last_processed_obs: TimedObservation = observation_t
+        preprocessing_time = time.perf_counter() - start_preprocess

-        """2. Get action chunk"""
-        start_time = time.perf_counter()
+        """3. Get action chunk"""
+        start_inference = time.perf_counter()
         action_tensor = self._get_action_chunk(observation)
-        inference_time = time.perf_counter() - start_time
-
-        self.logger.info(
-            f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}"
-        )
+        inference_time = time.perf_counter() - start_inference

-        """3. Post-inference processing"""
-        start_time = time.perf_counter()
-        # Move to CPU before serializing
-        action_tensor = action_tensor.cpu().squeeze(0)
+        """4. Apply postprocessor"""
+        # Apply postprocessor (handles unnormalization and device movement)
+        # Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim)
+        # So we process each action in the chunk individually
+        start_postprocess = time.perf_counter()
+        _, chunk_size, _ = action_tensor.shape
+
+        # Process each action in the chunk
+        processed_actions = []
+        for i in range(chunk_size):
+            # Extract action at timestep i: (B, action_dim)
+            single_action = action_tensor[:, i, :]
+            processed_action = self.postprocessor(single_action)
+            processed_actions.append(processed_action)
+
+        # Stack back to (B, chunk_size, action_dim), then remove batch dim
+        action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
+        self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
+
+        """5. Convert to TimedAction list"""
         action_chunk = self._time_action_chunk(
             observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
         )
-        postprocessing_time = time.perf_counter() - start_time
-        inference_stops = time.perf_counter()
+        postprocess_stops = time.perf_counter()
+        postprocessing_time = postprocess_stops - start_postprocess

         self.logger.info(
-            f"Observation {observation_t.get_timestep()} |"
-            f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms"
+            f"Observation {observation_t.get_timestep()} | "
+            f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
         )
+
+        # full-process latency breakdown for debugging purposes
         self.logger.debug(
             f"Observation {observation_t.get_timestep()} | "
-            f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | "
-            f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | "
-            f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | "
-            f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms"
+            f"Prepare time: {1000 * prepare_time:.2f}ms | "
+            f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | "
+            f"Inference time: {1000 * inference_time:.2f}ms | "
+            f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | "
+            f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms"
        )

         return action_chunk
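The per-timestep loop above exists because the postprocessor consumes one action of shape `(B, action_dim)` at a time, while the policy emits a whole `(B, chunk_size, action_dim)` chunk. Here is the pattern in isolation, with a toy unnormalization step standing in for the real pipeline (the stats are made up for illustration):

```python
import torch


def postprocess_chunk(chunk: torch.Tensor, postprocessor) -> torch.Tensor:
    """Apply a per-action postprocessor across a (B, chunk_size, action_dim) chunk."""
    _, chunk_size, _ = chunk.shape
    processed = [postprocessor(chunk[:, i, :]) for i in range(chunk_size)]  # each (B, action_dim)
    return torch.stack(processed, dim=1).squeeze(0)  # (chunk_size, action_dim), batch dim removed


# Toy stand-in for the real postprocessor: unnormalize with made-up per-joint stats
mean, std = torch.zeros(6), torch.full((6,), 2.0)
unnormalize = lambda action: action * std + mean

chunk = torch.randn(1, 20, 6)  # what _get_action_chunk would return
actions = postprocess_chunk(chunk, unnormalize)
assert actions.shape == (20, 6)  # one row per timestep, ready for _time_action_chunk
```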


@@ -52,6 +52,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.robots import (  # noqa: F401
     Robot,
     RobotConfig,
+    bi_so100_follower,
     koch_follower,
     make_robot_from_config,
     so100_follower,
@@ -214,7 +215,7 @@ class RobotClient:
         )
         _ = self.stub.SendObservations(observation_iterator)
         obs_timestep = obs.get_timestep()
-        self.logger.info(f"Sent observation #{obs_timestep} | ")
+        self.logger.debug(f"Sent observation #{obs_timestep} | ")

         return True
@@ -467,7 +468,7 @@ class RobotClient:
             if self._ready_to_send_observation():
                 _captured_observation = self.control_loop_observation(task, verbose)

-            self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")
+            self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}")

             # Dynamically adjust sleep time to maintain the desired control frequency
             time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start)))


@@ -91,6 +91,9 @@ def test_async_inference_e2e(monkeypatch):
     policy_server.policy = MockPolicy()
     policy_server.actions_per_chunk = 20
     policy_server.device = "cpu"
+    # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
+    policy_server.preprocessor = lambda obs: obs
+    policy_server.postprocessor = lambda tensor: tensor

     # Set up robot config and features
     robot_config = MockRobotConfig()


@@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic():
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Check that all expected keys are present
     assert OBS_STATE in observation
@@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic():
     # Check state processing
     state = observation[OBS_STATE]
     assert isinstance(state, torch.Tensor)
-    assert state.device.type == device
     assert state.shape == (1, 4)  # Batched

     # Check image processing
@@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic():
     assert laptop_img.shape == (1, 3, 224, 224)
     assert phone_img.shape == (1, 3, 160, 160)

-    # Check device placement
-    assert laptop_img.device.type == device
-    assert phone_img.device.type == device
-
     # Check image dtype and range (should be float32 in [0, 1])
     assert laptop_img.dtype == torch.float32
     assert phone_img.dtype == torch.float32
@@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data():
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Check that task string is preserved
     assert "task" in observation
@@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data():
 @torch.no_grad()
 def test_raw_observation_to_observation_device_handling():
-    """Test that tensors are properly moved to the specified device."""
-    device = "mps" if torch.backends.mps.is_available() else "cpu"
+    """Test that tensors are created (device placement is handled by preprocessor)."""
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

-    # Check that all tensors are on the correct device
+    # Check that all expected keys produce tensors (device placement handled by preprocessor later)
     for key, value in observation.items():
         if isinstance(value, torch.Tensor):
-            assert value.device.type == device, f"Tensor {key} not on {device}"
+            assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"


 def test_raw_observation_to_observation_deterministic():
@@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic():
     robot_obs = _create_mock_robot_observation()
     lerobot_features = _create_mock_lerobot_features()
     policy_image_features = _create_mock_policy_image_features()
-    device = "cpu"

     # Run twice with same input
-    obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
-    obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
+    obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
+    obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     # Results should be identical
     assert set(obs1.keys()) == set(obs2.keys())
@@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content():
         )
     }

-    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
+    observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)

     processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0)  # Remove batch dim


@@ -196,6 +196,9 @@ def test_predict_action_chunk(monkeypatch, policy_server):
     # Force server to act-style policy; patch method to return deterministic tensor
     policy_server.policy_type = "act"
+    # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix.
+    policy_server.preprocessor = lambda obs: obs
+    policy_server.postprocessor = lambda tensor: tensor

     action_dim = 6
     batch_size = 1
     actions_per_chunk = policy_server.actions_per_chunk