fix(async): Add pre and post processing to async inference and update docs (#2132)

* Add pre and post processing to async inference and update docs

* Fix typo flagged by pre-commit

* fix tests

* refactor(async): no None branching for processors in _predict_action_chunk

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Pepijn
2025-10-07 15:10:31 +02:00
committed by GitHub
parent fcaa0ea5f9
commit 9f32e00f90
8 changed files with 103 additions and 76 deletions

View File

@@ -91,6 +91,9 @@ def test_async_inference_e2e(monkeypatch):
policy_server.policy = MockPolicy()
policy_server.actions_per_chunk = 20
policy_server.device = "cpu"
# NOTE(Steven): This test is smelly because the Server is a state machine that is only partially mocked. As a quick fix, identity pre/post-processors are assigned here.
policy_server.preprocessor = lambda obs: obs
policy_server.postprocessor = lambda tensor: tensor
# Set up robot config and features
robot_config = MockRobotConfig()

View File

@@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic():
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
# Check that all expected keys are present
assert OBS_STATE in observation
@@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic():
# Check state processing
state = observation[OBS_STATE]
assert isinstance(state, torch.Tensor)
assert state.device.type == device
assert state.shape == (1, 4) # Batched
# Check image processing
@@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic():
assert laptop_img.shape == (1, 3, 224, 224)
assert phone_img.shape == (1, 3, 160, 160)
# Check device placement
assert laptop_img.device.type == device
assert phone_img.device.type == device
# Check image dtype and range (should be float32 in [0, 1])
assert laptop_img.dtype == torch.float32
assert phone_img.dtype == torch.float32
@@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data():
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
# Check that task string is preserved
assert "task" in observation
@@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data():
@torch.no_grad()
def test_raw_observation_to_observation_device_handling():
"""Test that tensors are properly moved to the specified device."""
device = "mps" if torch.backends.mps.is_available() else "cpu"
"""Test that tensors are created (device placement is handled by preprocessor)."""
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
# Check that all tensors are on the correct device
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
for key, value in observation.items():
if isinstance(value, torch.Tensor):
assert value.device.type == device, f"Tensor {key} not on {device}"
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
def test_raw_observation_to_observation_deterministic():
@@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic():
robot_obs = _create_mock_robot_observation()
lerobot_features = _create_mock_lerobot_features()
policy_image_features = _create_mock_policy_image_features()
device = "cpu"
# Run twice with same input
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
# Results should be identical
assert set(obs1.keys()) == set(obs2.keys())
@@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content():
)
}
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features)
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim

View File

@@ -196,6 +196,9 @@ def test_predict_action_chunk(monkeypatch, policy_server):
# Force server to act-style policy; patch method to return deterministic tensor
policy_server.policy_type = "act"
# NOTE(Steven): This test is smelly because the Server is a state machine that is only partially mocked. As a quick fix, identity pre/post-processors are assigned here.
policy_server.preprocessor = lambda obs: obs
policy_server.postprocessor = lambda tensor: tensor
action_dim = 6
batch_size = 1
actions_per_chunk = policy_server.actions_per_chunk