forked from tangger/lerobot
Add Async Inference (#1196)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
ce2b9724bf
commit
30c161006d
177
tests/async_inference/test_e2e.py
Normal file
177
tests/async_inference/test_e2e.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end test of the asynchronous inference stack (client ↔ server).
|
||||
|
||||
This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed
|
||||
policy network and launches a `RobotClient` that uses a `MockRobot`. The goal
|
||||
is to exercise the full communication loop:
|
||||
|
||||
1. Client sends policy specification → Server
|
||||
2. Client streams observations → Server
|
||||
3. Server streams action chunks → Client
|
||||
4. Client executes received actions
|
||||
|
||||
The test succeeds if at least one action is executed and the server records at
|
||||
least one predicted timestep - demonstrating that the gRPC round-trip works
|
||||
end-to-end using real (but lightweight) protocol messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from concurrent import futures
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if grpc is not available
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# End-to-end test
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_async_inference_e2e(monkeypatch):
|
||||
"""Tests the full asynchronous inference pipeline."""
|
||||
# Import grpc-dependent modules inside the test function
|
||||
import grpc
|
||||
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig
|
||||
from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features
|
||||
from lerobot.scripts.server.policy_server import PolicyServer
|
||||
from lerobot.scripts.server.robot_client import RobotClient
|
||||
from lerobot.transport import (
|
||||
async_inference_pb2, # type: ignore
|
||||
async_inference_pb2_grpc, # type: ignore
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
# Create a stub policy similar to test_policy_server.py
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self):
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def model(self, batch):
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Create PolicyServer instance with mock policy
|
||||
# ------------------------------------------------------------------
|
||||
policy_server_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
policy_server = PolicyServer(policy_server_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
policy_server.policy = MockPolicy()
|
||||
policy_server.actions_per_chunk = 20
|
||||
policy_server.device = "cpu"
|
||||
|
||||
# Set up robot config and features
|
||||
robot_config = MockRobotConfig()
|
||||
mock_robot = make_robot_from_config(robot_config)
|
||||
|
||||
lerobot_features = map_robot_keys_to_lerobot_features(mock_robot)
|
||||
policy_server.lerobot_features = lerobot_features
|
||||
|
||||
# Force server to produce deterministic action chunks in test mode
|
||||
policy_server.policy_type = "act"
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="test"):
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
# Bypass potentially heavy model loading inside SendPolicyInstructions
|
||||
def _fake_send_policy_instructions(self, request, context): # noqa: N802
|
||||
return async_inference_pb2.Empty()
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True)
|
||||
|
||||
# Build gRPC server running a PolicyServer
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server"))
|
||||
async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
|
||||
|
||||
# Use the host/port specified in the fixture's config
|
||||
server_address = f"{policy_server.config.host}:{policy_server.config.port}"
|
||||
server.add_insecure_port(server_address)
|
||||
server.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Create a RobotClient around the MockRobot
|
||||
# ------------------------------------------------------------------
|
||||
client_config = RobotClientConfig(
|
||||
server_address=server_address,
|
||||
robot=robot_config,
|
||||
chunk_size_threshold=0.0,
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
verify_robot_cameras=False,
|
||||
)
|
||||
|
||||
client = RobotClient(client_config)
|
||||
assert client.start(), "Client failed initial handshake with the server"
|
||||
|
||||
# Track action chunks received without modifying RobotClient
|
||||
action_chunks_received = {"count": 0}
|
||||
original_aggregate = client._aggregate_action_queues
|
||||
|
||||
def counting_aggregate(*args, **kwargs):
|
||||
action_chunks_received["count"] += 1
|
||||
return original_aggregate(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||
|
||||
# Start client threads
|
||||
action_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True)
|
||||
action_thread.start()
|
||||
control_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. System exchanges a few messages
|
||||
# ------------------------------------------------------------------
|
||||
# Wait for 5 seconds
|
||||
server.wait_for_termination(timeout=5)
|
||||
|
||||
assert action_chunks_received["count"] > 0, "Client did not receive any action chunks"
|
||||
assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Stop the system
|
||||
# ------------------------------------------------------------------
|
||||
client.stop()
|
||||
action_thread.join()
|
||||
control_thread.join()
|
||||
policy_server.stop()
|
||||
server.stop(grace=None)
|
||||
459
tests/async_inference/test_helpers.py
Normal file
459
tests/async_inference/test_helpers.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.scripts.server.helpers import (
|
||||
FPSTracker,
|
||||
TimedAction,
|
||||
TimedObservation,
|
||||
observations_similar,
|
||||
prepare_image,
|
||||
prepare_raw_observation,
|
||||
raw_observation_to_observation,
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fps_tracker_first_observation():
|
||||
"""First observation should initialize timestamp and return 0 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
timestamp = 1000.0
|
||||
|
||||
metrics = tracker.calculate_fps_metrics(timestamp)
|
||||
|
||||
assert tracker.first_timestamp == timestamp
|
||||
assert tracker.total_obs_count == 1
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
assert metrics["target_fps"] == 30.0
|
||||
|
||||
|
||||
def test_fps_tracker_single_interval():
|
||||
"""Two observations 1 second apart should give 1 FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# First observation at t=0
|
||||
metrics1 = tracker.calculate_fps_metrics(0.0)
|
||||
assert metrics1["avg_fps"] == 0.0
|
||||
|
||||
# Second observation at t=1 (1 second later)
|
||||
metrics2 = tracker.calculate_fps_metrics(1.0)
|
||||
expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS
|
||||
assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_multiple_intervals():
|
||||
"""Multiple observations should calculate correct average FPS."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Simulate 5 observations over 2 seconds (should be 2 FPS average)
|
||||
timestamps = [0.0, 0.5, 1.0, 1.5, 2.0]
|
||||
|
||||
for i, ts in enumerate(timestamps):
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
if i == 0:
|
||||
assert metrics["avg_fps"] == 0.0
|
||||
elif i == len(timestamps) - 1:
|
||||
# After 5 observations over 2 seconds: (5-1)/2 = 2 FPS
|
||||
expected_fps = 2.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
def test_fps_tracker_irregular_intervals():
|
||||
"""FPS calculation should work with irregular time intervals."""
|
||||
tracker = FPSTracker(target_fps=30.0)
|
||||
|
||||
# Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds
|
||||
timestamps = [0.0, 0.1, 0.5, 2.0, 3.0]
|
||||
|
||||
for ts in timestamps:
|
||||
metrics = tracker.calculate_fps_metrics(ts)
|
||||
|
||||
# 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS
|
||||
expected_fps = 4.0 / 3.0
|
||||
assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# TimedData helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_timed_action_getters():
|
||||
"""TimedAction stores & returns timestamp, action tensor and timestep."""
|
||||
ts = time.time()
|
||||
action = torch.arange(10)
|
||||
ta = TimedAction(timestamp=ts, action=action, timestep=0)
|
||||
|
||||
assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
torch.testing.assert_close(ta.get_action(), action)
|
||||
assert ta.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_observation_getters():
|
||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
||||
ts = time.time()
|
||||
obs_dict = {"observation.state": torch.ones(6)}
|
||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
||||
|
||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to.get_observation() is obs_dict
|
||||
assert to.get_timestep() == 0
|
||||
|
||||
|
||||
def test_timed_data_deserialization_data_getters():
|
||||
"""TimedAction / TimedObservation survive a round-trip through ``pickle``.
|
||||
|
||||
The async-inference stack uses ``pickle.dumps`` to move these objects across
|
||||
the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions).
|
||||
This test ensures that the payload keeps its content intact after
|
||||
the (de)serialization round-trip.
|
||||
"""
|
||||
ts = time.time()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedAction
|
||||
# ------------------------------------------------------------------
|
||||
original_action = torch.randn(6)
|
||||
ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13)
|
||||
|
||||
# Serialize → bytes → deserialize
|
||||
ta_bytes = pickle.dumps(ta_in) # nosec
|
||||
ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301
|
||||
|
||||
# Identity & content checks
|
||||
assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert ta_out.get_timestep() == 13
|
||||
torch.testing.assert_close(ta_out.get_action(), original_action)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# TimedObservation
|
||||
# ------------------------------------------------------------------
|
||||
obs_dict = {"observation.state": torch.arange(4).float()}
|
||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
||||
|
||||
to_bytes = pickle.dumps(to_in) # nosec
|
||||
to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301
|
||||
|
||||
assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
assert to_out.get_timestep() == 7
|
||||
assert to_out.must_go is True
|
||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
||||
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# observations_similar()
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor) -> TimedObservation:
|
||||
"""Create a TimedObservation with raw robot observation format."""
|
||||
return TimedObservation(
|
||||
timestamp=time.time(),
|
||||
observation={
|
||||
"shoulder": state[0].item() if len(state) > 0 else 0.0,
|
||||
"elbow": state[1].item() if len(state) > 1 else 0.0,
|
||||
"wrist": state[2].item() if len(state) > 2 else 0.0,
|
||||
"gripper": state[3].item() if len(state) > 3 else 0.0,
|
||||
},
|
||||
timestep=0,
|
||||
)
|
||||
|
||||
|
||||
def test_observations_similar_true():
|
||||
"""Distance below atol → observations considered similar."""
|
||||
# Create mock lerobot features for the similarity check
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
}
|
||||
}
|
||||
|
||||
obs1 = _make_obs(torch.zeros(4))
|
||||
obs2 = _make_obs(0.5 * torch.ones(4))
|
||||
assert observations_similar(obs1, obs2, lerobot_features, atol=2.0)
|
||||
|
||||
obs3 = _make_obs(2.0 * torch.ones(4))
|
||||
assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# raw_observation_to_observation and helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_mock_robot_observation():
|
||||
"""Create a mock robot observation with motor positions and camera images."""
|
||||
return {
|
||||
"shoulder": 1.0,
|
||||
"elbow": 2.0,
|
||||
"wrist": 3.0,
|
||||
"gripper": 0.5,
|
||||
"laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
"phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8),
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _create_mock_policy_image_features():
|
||||
"""Create mock policy image features with different resolutions."""
|
||||
return {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
||||
),
|
||||
"observation.images.phone": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 160, 160), # Different resolution for second camera
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_prepare_image():
|
||||
"""Test image preprocessing: int8 → float32, normalization to [0,1]."""
|
||||
# Create mock int8 image data
|
||||
image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8)
|
||||
|
||||
processed = prepare_image(image_int8)
|
||||
|
||||
# Check dtype conversion
|
||||
assert processed.dtype == torch.float32
|
||||
|
||||
# Check normalization range
|
||||
assert processed.min() >= 0.0
|
||||
assert processed.max() <= 1.0
|
||||
|
||||
# Check that values are scaled correctly (255 → 1.0, 0 → 0.0)
|
||||
if image_int8.max() == 255:
|
||||
assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6)
|
||||
if image_int8.min() == 0:
|
||||
assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6)
|
||||
|
||||
# Check memory contiguity
|
||||
assert processed.is_contiguous()
|
||||
|
||||
|
||||
def test_resize_robot_observation_image():
|
||||
"""Test image resizing from robot resolution to policy resolution."""
|
||||
# Create mock image: (H=480, W=640, C=3)
|
||||
original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8)
|
||||
target_shape = (3, 224, 224) # (C, H, W)
|
||||
|
||||
resized = resize_robot_observation_image(original_image, target_shape)
|
||||
|
||||
# Check output shape matches target
|
||||
assert resized.shape == target_shape
|
||||
|
||||
# Check that original image had different dimensions
|
||||
assert original_image.shape != resized.shape
|
||||
|
||||
# Check that resizing preserves value range
|
||||
assert resized.min() >= 0
|
||||
assert resized.max() <= 255
|
||||
|
||||
|
||||
def test_prepare_raw_observation():
|
||||
"""Test the preparation of raw robot observation to lerobot format."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that state is properly extracted and batched
|
||||
assert "observation.state" in prepared
|
||||
state = prepared["observation.state"]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched state
|
||||
|
||||
# Check that images are processed and resized
|
||||
assert "observation.images.laptop" in prepared
|
||||
assert "observation.images.phone" in prepared
|
||||
|
||||
laptop_img = prepared["observation.images.laptop"]
|
||||
phone_img = prepared["observation.images.phone"]
|
||||
|
||||
# Check image shapes match policy requirements
|
||||
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
|
||||
|
||||
# Check that images are tensors
|
||||
assert isinstance(laptop_img, torch.Tensor)
|
||||
assert isinstance(phone_img, torch.Tensor)
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_basic():
|
||||
"""Test the main raw_observation_to_observation function."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert "observation.state" in observation
|
||||
assert "observation.images.laptop" in observation
|
||||
assert "observation.images.phone" in observation
|
||||
|
||||
# Check state processing
|
||||
state = observation["observation.state"]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.device.type == device
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
laptop_img = observation["observation.images.laptop"]
|
||||
phone_img = observation["observation.images.phone"]
|
||||
|
||||
# Images should have batch dimension: (B, C, H, W)
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
assert phone_img.shape == (1, 3, 160, 160)
|
||||
|
||||
# Check device placement
|
||||
assert laptop_img.device.type == device
|
||||
assert phone_img.device.type == device
|
||||
|
||||
# Check image dtype and range (should be float32 in [0, 1])
|
||||
assert laptop_img.dtype == torch.float32
|
||||
assert phone_img.dtype == torch.float32
|
||||
assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0
|
||||
assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_with_non_tensor_data():
|
||||
"""Test that non-tensor data (like task strings) is preserved."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
robot_obs["task"] = "pick up the red cube" # Add string instruction
|
||||
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Check that task string is preserved
|
||||
assert "task" in observation
|
||||
assert observation["task"] == "pick up the red cube"
|
||||
assert isinstance(observation["task"], str)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_raw_observation_to_observation_device_handling():
|
||||
"""Test that tensors are properly moved to the specified device."""
|
||||
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Check that all tensors are on the correct device
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type == device, f"Tensor {key} not on {device}"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
"""Test that the function produces consistent results for the same input."""
|
||||
robot_obs = _create_mock_robot_observation()
|
||||
lerobot_features = _create_mock_lerobot_features()
|
||||
policy_image_features = _create_mock_policy_image_features()
|
||||
device = "cpu"
|
||||
|
||||
# Run twice with same input
|
||||
obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Results should be identical
|
||||
assert set(obs1.keys()) == set(obs2.keys())
|
||||
|
||||
for key in obs1:
|
||||
if isinstance(obs1[key], torch.Tensor):
|
||||
torch.testing.assert_close(obs1[key], obs2[key])
|
||||
else:
|
||||
assert obs1[key] == obs2[key]
|
||||
|
||||
|
||||
def test_image_processing_pipeline_preserves_content():
|
||||
"""Test that the image processing pipeline preserves recognizable patterns."""
|
||||
# Create an image with a specific pattern
|
||||
original_img = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
original_img[25:75, 25:75, :] = 255 # White square in center
|
||||
|
||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [100, 100, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
policy_image_features = {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 50, 50), # Downsamples from 100x100
|
||||
)
|
||||
}
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
||||
|
||||
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
# Check that the center region has higher values than corners
|
||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
||||
center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image
|
||||
corner_val = processed_img[:, 5, 5].mean() # Corner
|
||||
|
||||
assert center_val > corner_val, "Image processing should preserve recognizable patterns"
|
||||
215
tests/async_inference/test_policy_server.py
Normal file
215
tests/async_inference/test_policy_server.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `PolicyServer` core logic.
|
||||
Monkey-patch the `policy` attribute with a stub so that no real model inference is performed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockPolicy:
|
||||
"""A minimal mock for an actual policy, returning zeros.
|
||||
Refer to tests/policies for tests of the individual policies supported."""
|
||||
|
||||
class _Config:
|
||||
robot_type = "dummy_robot"
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
"""Empty image features since this test doesn't use images."""
|
||||
return {}
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation["observation.state"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._Config()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
# The server calls `policy.to(device)`. This stub ignores it.
|
||||
return self
|
||||
|
||||
def model(self, batch: dict) -> torch.Tensor:
|
||||
# Return a chunk of 20 dummy actions.
|
||||
batch_size = len(batch["robot_type"])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@require_package("grpc")
|
||||
def policy_server():
|
||||
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.scripts.server.configs import PolicyServerConfig
|
||||
from lerobot.scripts.server.policy_server import PolicyServer
|
||||
|
||||
test_config = PolicyServerConfig(host="localhost", port=9999)
|
||||
server = PolicyServer(test_config)
|
||||
# Replace the real policy with our fast, deterministic stub.
|
||||
server.policy = MockPolicy()
|
||||
server.actions_per_chunk = 20
|
||||
server.device = "cpu"
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
}
|
||||
}
|
||||
|
||||
return server
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False):
|
||||
"""Create a TimedObservation with a given state vector."""
|
||||
# Import only when needed
|
||||
from lerobot.scripts.server.helpers import TimedObservation
|
||||
|
||||
return TimedObservation(
|
||||
observation={
|
||||
"joint1": state[0].item() if len(state) > 0 else 0.0,
|
||||
"joint2": state[1].item() if len(state) > 1 else 0.0,
|
||||
"joint3": state[2].item() if len(state) > 2 else 0.0,
|
||||
"joint4": state[3].item() if len(state) > 3 else 0.0,
|
||||
"joint5": state[4].item() if len(state) > 4 else 0.0,
|
||||
"joint6": state[5].item() if len(state) > 5 else 0.0,
|
||||
},
|
||||
timestamp=time.time(),
|
||||
timestep=timestep,
|
||||
must_go=must_go,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_time_action_chunk(policy_server):
|
||||
"""Verify that `_time_action_chunk` assigns correct timestamps and timesteps."""
|
||||
start_ts = time.time()
|
||||
start_t = 10
|
||||
# A chunk of 3 action tensors.
|
||||
action_tensors = [torch.randn(6) for _ in range(3)]
|
||||
|
||||
timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t)
|
||||
|
||||
assert len(timed_actions) == 3
|
||||
# Check timesteps
|
||||
assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12]
|
||||
# Check timestamps
|
||||
expected_timestamps = [
|
||||
start_ts,
|
||||
start_ts + policy_server.config.environment_dt,
|
||||
start_ts + 2 * policy_server.config.environment_dt,
|
||||
]
|
||||
for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True):
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_must_go(policy_server):
|
||||
"""An observation with `must_go=True` is always enqueued."""
|
||||
obs = _make_obs(torch.zeros(6), must_go=True)
|
||||
assert policy_server._enqueue_observation(obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
assert policy_server.observation_queue.get_nowait() is obs
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_dissimilar(policy_server):
|
||||
"""A dissimilar observation (not `must_go`) is enqueued."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, dissimilar observation.
|
||||
new_obs = _make_obs(torch.ones(6) * 5) # High norm difference
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is True
|
||||
assert policy_server.observation_queue.qsize() == 1
|
||||
|
||||
|
||||
def test_maybe_enqueue_observation_is_skipped(policy_server):
|
||||
"""A similar observation (not `must_go`) is skipped."""
|
||||
# Set a last predicted observation.
|
||||
policy_server.last_processed_obs = _make_obs(torch.zeros(6))
|
||||
# Create a new, very similar observation.
|
||||
new_obs = _make_obs(torch.zeros(6) + 1e-4)
|
||||
|
||||
assert policy_server._enqueue_observation(new_obs) is False
|
||||
assert policy_server.observation_queue.empty() is True
|
||||
|
||||
|
||||
def test_obs_sanity_checks(policy_server):
|
||||
"""Unit-test the private `_obs_sanity_checks` helper."""
|
||||
prev = _make_obs(torch.zeros(6), timestep=0)
|
||||
|
||||
# Case 1 – timestep already predicted
|
||||
policy_server._predicted_timesteps.add(1)
|
||||
obs_same_ts = _make_obs(torch.ones(6), timestep=1)
|
||||
assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False
|
||||
|
||||
# Case 2 – observation too similar
|
||||
policy_server._predicted_timesteps.clear()
|
||||
obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2)
|
||||
assert policy_server._obs_sanity_checks(obs_similar, prev) is False
|
||||
|
||||
# Case 3 – genuinely new & dissimilar observation passes
|
||||
obs_ok = _make_obs(torch.ones(6) * 5, timestep=3)
|
||||
assert policy_server._obs_sanity_checks(obs_ok, prev) is True
|
||||
|
||||
|
||||
def test_predict_action_chunk(monkeypatch, policy_server):
|
||||
"""End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk."""
|
||||
# Import only when needed
|
||||
from lerobot.scripts.server.policy_server import PolicyServer
|
||||
|
||||
# Force server to act-style policy; patch method to return deterministic tensor
|
||||
policy_server.policy_type = "act"
|
||||
action_dim = 6
|
||||
batch_size = 1
|
||||
actions_per_chunk = policy_server.actions_per_chunk
|
||||
|
||||
def _fake_get_action_chunk(_self, _obs, _type="act"):
|
||||
return torch.zeros(batch_size, actions_per_chunk, action_dim)
|
||||
|
||||
monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True)
|
||||
|
||||
obs = _make_obs(torch.zeros(6), timestep=5)
|
||||
timed_actions = policy_server._predict_action_chunk(obs)
|
||||
|
||||
assert len(timed_actions) == actions_per_chunk
|
||||
assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk))
|
||||
|
||||
for i, ta in enumerate(timed_actions):
|
||||
expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt
|
||||
assert abs(ta.get_timestamp() - expected_ts) < 1e-6
|
||||
234
tests/async_inference/test_robot_client.py
Normal file
234
tests/async_inference/test_robot_client.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC).
|
||||
|
||||
We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that
|
||||
no real hardware is accessed. Only the queue-update mechanism is verified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if grpc is not available
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test fixtures
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def robot_client():
|
||||
"""Fresh `RobotClient` instance for each test case (no threads started).
|
||||
Uses DummyRobot."""
|
||||
# Import only when the test actually runs (after decorator check)
|
||||
from lerobot.scripts.server.configs import RobotClientConfig
|
||||
from lerobot.scripts.server.robot_client import RobotClient
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
|
||||
test_config = MockRobotConfig()
|
||||
|
||||
# gRPC channel is not actually used in tests, so using a dummy address
|
||||
test_config = RobotClientConfig(
|
||||
robot=test_config,
|
||||
server_address="localhost:9999",
|
||||
policy_type="test",
|
||||
pretrained_name_or_path="test",
|
||||
actions_per_chunk=20,
|
||||
verify_robot_cameras=False,
|
||||
)
|
||||
|
||||
client = RobotClient(test_config)
|
||||
|
||||
# Initialize attributes that are normally set in start() method
|
||||
client.chunks_received = 0
|
||||
client.available_actions_size = []
|
||||
|
||||
yield client
|
||||
|
||||
if client.robot.is_connected:
|
||||
client.stop()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper utilities for tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_actions(start_ts: float, start_t: int, count: int):
|
||||
"""Generate `count` consecutive TimedAction objects starting at timestep `start_t`."""
|
||||
from lerobot.scripts.server.helpers import TimedAction
|
||||
|
||||
fps = 30 # emulates most common frame-rate
|
||||
actions = []
|
||||
for i in range(count):
|
||||
timestep = start_t + i
|
||||
timestamp = start_ts + i * (1 / fps)
|
||||
action_tensor = torch.full((6,), timestep, dtype=torch.float32)
|
||||
actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp))
|
||||
return actions
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_action_queue_discards_stale(robot_client):
|
||||
"""`_update_action_queue` must drop actions with `timestep` <= `latest_action`."""
|
||||
|
||||
# Pretend we already executed up to action #4
|
||||
robot_client.latest_action = 4
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
robot_client._aggregate_action_queues(incoming)
|
||||
|
||||
# Extract timesteps from queue
|
||||
resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue]
|
||||
|
||||
assert resulting_timesteps == [5, 6, 7]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weight_old, weight_new",
|
||||
[
|
||||
(1.0, 0.0),
|
||||
(0.0, 1.0),
|
||||
(0.5, 0.5),
|
||||
(0.2, 0.8),
|
||||
(0.8, 0.2),
|
||||
(0.1, 0.9),
|
||||
(0.9, 0.1),
|
||||
],
|
||||
)
|
||||
def test_aggregate_action_queues_combines_actions_in_overlap(
|
||||
robot_client, weight_old: float, weight_new: float
|
||||
):
|
||||
"""`_aggregate_action_queues` must combine actions on overlapping timesteps according
|
||||
to the provided aggregate_fn, here tested with multiple coefficients."""
|
||||
from lerobot.scripts.server.helpers import TimedAction
|
||||
|
||||
robot_client.chunks_received = 0
|
||||
|
||||
# Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6
|
||||
robot_client.latest_action = 4
|
||||
current_actions = _make_actions(
|
||||
start_ts=time.time(), start_t=5, count=2
|
||||
) # actions are [torch.ones(6), torch.ones(6), ...]
|
||||
current_actions = [
|
||||
TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp())
|
||||
for a in current_actions
|
||||
]
|
||||
|
||||
for a in current_actions:
|
||||
robot_client.action_queue.put(a)
|
||||
|
||||
# Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept.
|
||||
incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7
|
||||
|
||||
overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale
|
||||
nonoverlap_timesteps = [7]
|
||||
|
||||
robot_client._aggregate_action_queues(
|
||||
incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2
|
||||
)
|
||||
|
||||
queue_overlap_actions = []
|
||||
queue_non_overlap_actions = []
|
||||
for a in robot_client.action_queue.queue:
|
||||
if a.get_timestep() in overlap_timesteps:
|
||||
queue_overlap_actions.append(a)
|
||||
elif a.get_timestep() in nonoverlap_timesteps:
|
||||
queue_non_overlap_actions.append(a)
|
||||
|
||||
queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep())
|
||||
queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep())
|
||||
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[0].get_action(),
|
||||
weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(),
|
||||
)
|
||||
assert torch.allclose(
|
||||
queue_overlap_actions[1].get_action(),
|
||||
weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(),
|
||||
)
|
||||
assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_size, queue_len, expected",
|
||||
[
|
||||
(20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send
|
||||
(20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send
|
||||
(10, 5, True),
|
||||
(10, 6, False),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` ratio logic for various sizes."""
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
|
||||
# Clear any existing actions then fill with `queue_len` dummy entries ----
|
||||
robot_client.action_queue = Queue()
|
||||
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"g_threshold, expected",
|
||||
[
|
||||
# The condition is `queue_size / chunk_size <= g`.
|
||||
# Here, ratio = 6 / 10 = 0.6.
|
||||
(0.0, False), # 0.6 <= 0.0 is False
|
||||
(0.1, False),
|
||||
(0.2, False),
|
||||
(0.3, False),
|
||||
(0.4, False),
|
||||
(0.5, False),
|
||||
(0.6, True), # 0.6 <= 0.6 is True
|
||||
(0.7, True),
|
||||
(0.8, True),
|
||||
(0.9, True),
|
||||
(1.0, True),
|
||||
],
|
||||
)
|
||||
def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool):
|
||||
"""Validate `_ready_to_send_observation` with fixed sizes and varying `g`."""
|
||||
# Fixed sizes for this test: ratio = 6 / 10 = 0.6
|
||||
chunk_size = 10
|
||||
queue_len = 6
|
||||
|
||||
robot_client.action_chunk_size = chunk_size
|
||||
# This is the parameter we are testing
|
||||
robot_client._chunk_size_threshold = g_threshold
|
||||
|
||||
# Fill queue with dummy actions
|
||||
robot_client.action_queue = Queue()
|
||||
dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len)
|
||||
for act in dummy_actions:
|
||||
robot_client.action_queue.put(act)
|
||||
|
||||
assert robot_client._ready_to_send_observation() is expected
|
||||
Reference in New Issue
Block a user