lerobot/tests/async_inference/test_policy_server.py

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit-tests for the `PolicyServer` core logic.
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
"""
from __future__ import annotations

import time

import pytest
import torch

from lerobot.configs.types import PolicyFeature
from tests.utils import require_package

# -----------------------------------------------------------------------------
# Test fixtures
# -----------------------------------------------------------------------------
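# The pattern used throughout this module, in short: a real `PolicyServer` is built, and its
# `policy` attribute is then swapped for the `MockPolicy` stub defined just below, e.g.
#
#     server = PolicyServer(PolicyServerConfig(host="localhost", port=9999))
#     server.policy = MockPolicy()  # no real inference ever happens
#
# Every test then drives the server's queueing and chunking logic against this stub.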
class MockPolicy:
    """A minimal mock for an actual policy, returning zeros.

    Refer to tests/policies for tests of the individual policies supported.
    """

    class _Config:
        robot_type = "dummy_robot"

    def __init__(self):
        self.config = self._Config()

    def to(self, *args, **kwargs):
        # The server calls `policy.to(device)`. This stub ignores it.
        return self

    @property
    def image_features(self) -> dict[str, PolicyFeature]:
        """Empty image features since this test doesn't use images."""
        return {}

    def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return a chunk of 20 dummy actions."""
        batch_size = len(observation["observation.state"])
        return torch.zeros(batch_size, 20, 6)

    def model(self, batch: dict) -> torch.Tensor:
        # Return a chunk of 20 dummy actions.
        batch_size = len(batch["robot_type"])
        return torch.zeros(batch_size, 20, 6)
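
# Illustrative self-check of the `MockPolicy` stub defined above (a minimal sketch): it
# should return an all-zero action chunk of shape (batch, 20 actions, 6 action dims).
def test_mock_policy_returns_zero_chunk():
    policy = MockPolicy()
    chunk = policy.predict_action_chunk({"observation.state": torch.zeros(1, 6)})

    assert chunk.shape == (1, 20, 6)
    assert torch.equal(chunk, torch.zeros(1, 20, 6))
    assert policy.config.robot_type == "dummy_robot"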

@pytest.fixture
@require_package("grpc")
def policy_server():
    """Fresh `PolicyServer` instance with a stubbed-out policy model."""
    # Import only when the test actually runs (after the decorator check).
    from lerobot.scripts.server.configs import PolicyServerConfig
    from lerobot.scripts.server.policy_server import PolicyServer

    test_config = PolicyServerConfig(host="localhost", port=9999)
    server = PolicyServer(test_config)

    # Replace the real policy with our fast, deterministic stub.
    server.policy = MockPolicy()
    server.actions_per_chunk = 20
    server.device = "cpu"

    # Add mock lerobot_features that the observation similarity functions need.
    server.lerobot_features = {
        "observation.state": {
            "dtype": "float32",
            "shape": [6],
            "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
        }
    }

    return server
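
# Note: the six joint names in `lerobot_features` above intentionally match the keys
# produced by `_make_obs` below, presumably so the server's observation-similarity
# checks can turn the stubbed observations into a 6-D state vector.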

# -----------------------------------------------------------------------------
# Helper utilities for tests
# -----------------------------------------------------------------------------


def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
    """Create a TimedObservation with a given state vector."""
    # Import only when needed.
    from lerobot.scripts.server.helpers import TimedObservation

    return TimedObservation(
        observation={
            "joint1": state[0].item() if len(state) > 0 else 0.0,
            "joint2": state[1].item() if len(state) > 1 else 0.0,
            "joint3": state[2].item() if len(state) > 2 else 0.0,
            "joint4": state[3].item() if len(state) > 3 else 0.0,
            "joint5": state[4].item() if len(state) > 4 else 0.0,
            "joint6": state[5].item() if len(state) > 5 else 0.0,
        },
        timestamp=time.time(),
        timestep=timestep,
        must_go=must_go,
    )
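
# For example, `_make_obs(torch.arange(6.0), timestep=3)` yields a TimedObservation whose
# `observation` dict is {"joint1": 0.0, ..., "joint6": 5.0}, stamped with the current
# wall-clock time, timestep 3 and must_go=False.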

# -----------------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------------


def test_time_action_chunk(policy_server):
    """Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
    start_ts = time.time()
    start_t = 10

    # A chunk of 3 action tensors.
    action_tensors = [torch.randn(6) for _ in range(3)]

    timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)

    assert len(timed_actions) == 3

    # Check timesteps.
    assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]

    # Check timestamps.
    expected_timestamps = [
        start_ts,
        start_ts + policy_server.config.environment_dt,
        start_ts + 2 * policy_server.config.environment_dt,
    ]
    for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
        assert abs(ta.get_timestamp() - expected_ts) < 1e-6
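
# In other words, `_time_action_chunk` is expected to stamp the i-th action at
# start_ts + i * environment_dt (taken from the server config), while timesteps
# simply increment from `start_t`.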

def test_maybe_enqueue_observation_must_go(policy_server):
    """An observation with `must_go=True` is always enqueued."""
    obs = _make_obs(torch.zeros(6), must_go=True)

    assert policy_server._enqueue_observation(obs) is True
    assert policy_server.observation_queue.qsize() == 1
    assert policy_server.observation_queue.get_nowait() is obs


def test_maybe_enqueue_observation_dissimilar(policy_server):
    """A dissimilar observation (not `must_go`) is enqueued."""
    # Set a last processed observation.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))

    # Create a new, dissimilar observation.
    new_obs = _make_obs(torch.ones(6) * 5)  # High norm difference

    assert policy_server._enqueue_observation(new_obs) is True
    assert policy_server.observation_queue.qsize() == 1


def test_maybe_enqueue_observation_is_skipped(policy_server):
    """A similar observation (not `must_go`) is skipped."""
    # Set a last processed observation.
    policy_server.last_processed_obs = _make_obs(torch.zeros(6))

    # Create a new, very similar observation.
    new_obs = _make_obs(torch.zeros(6) + 1e-4)

    assert policy_server._enqueue_observation(new_obs) is False
    assert policy_server.observation_queue.empty() is True
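
# Taken together, the three tests above pin down the enqueueing contract: `must_go`
# observations are always queued, observations that differ enough from the last
# processed one are queued, and near-duplicates are dropped.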

def test_obs_sanity_checks(policy_server):
    """Unit-test the private `_obs_sanity_checks` helper."""
    prev = _make_obs(torch.zeros(6), timestep=0)

    # Case 1: timestep already predicted.
    policy_server._predicted_timesteps.add(1)
    obs_same_ts = _make_obs(torch.ones(6), timestep=1)
    assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False

    # Case 2: observation too similar.
    policy_server._predicted_timesteps.clear()
    obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
    assert policy_server._obs_sanity_checks(obs_similar, prev) is False

    # Case 3: genuinely new & dissimilar observation passes.
    obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
    assert policy_server._obs_sanity_checks(obs_ok, prev) is True

def test_predict_action_chunk(monkeypatch, policy_server):
    """End-to-end test of `_predict_action_chunk` with a stubbed `_get_action_chunk`."""
    # Import only when needed.
    from lerobot.scripts.server.policy_server import PolicyServer

    # Force the server onto an "act"-style policy; patch the method to return a deterministic tensor.
    policy_server.policy_type = "act"
    action_dim = 6
    batch_size = 1
    actions_per_chunk = policy_server.actions_per_chunk

    def _fake_get_action_chunk(_self, _obs, _type="act"):
        return torch.zeros(batch_size, actions_per_chunk, action_dim)

    # Patching at the class level means the stub is called as a bound method,
    # so it receives the server instance as `_self`.
    monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)

    obs = _make_obs(torch.zeros(6), timestep=5)
    timed_actions = policy_server._predict_action_chunk(obs)

    assert len(timed_actions) == actions_per_chunk
    assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))

    for i, ta in enumerate(timed_actions):
        expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
        assert abs(ta.get_timestamp() - expected_ts) < 1e-6