Port HIL SERL (#644)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-06-13 13:15:47 +02:00
committed by GitHub
parent f976935ba1
commit d8079587a2
61 changed files with 14066 additions and 163 deletions

208
tests/rl/test_actor.py Normal file
View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent import futures
from unittest.mock import patch
import pytest
import torch
from torch.multiprocessing import Event, Queue
from lerobot.common.utils.transition import Transition
from tests.utils import require_package
def create_learner_service_stub():
    """Start a mock LearnerService gRPC server and return a connected stub.

    Returns:
        A tuple ``(stub, servicer, channel, server)``. Tests can flip
        ``servicer.should_fail`` to simulate an unavailable learner, and must
        clean up with ``close_service_stub(channel, server)``.
    """
    import grpc

    from lerobot.common.transport import services_pb2, services_pb2_grpc

    class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
        """Minimal servicer that counts Ready calls and can simulate failure."""

        def __init__(self):
            self.ready_call_count = 0
            self.should_fail = False

        def Ready(self, request, context):  # noqa: N802
            self.ready_call_count += 1
            if self.should_fail:
                context.set_code(grpc.StatusCode.UNAVAILABLE)
                context.set_details("Service unavailable")
                raise grpc.RpcError("Service unavailable")
            return services_pb2.Empty()

    servicer = MockLearnerService()

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by the OS
    server.start()  # non-blocking call

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server
def close_service_stub(channel, server):
    """Close the client channel, then stop the gRPC server without a grace period."""
    channel.close()
    server.stop(grace=None)
@require_package("grpc")
def test_establish_learner_connection_success():
    """Test successful connection establishment."""
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, _servicer, channel, server = create_learner_service_stub()
    shutdown_event = Event()

    # Test successful connection
    result = establish_learner_connection(stub, shutdown_event, attempts=5)
    assert result is True

    close_service_stub(channel, server)
@require_package("grpc")
def test_establish_learner_connection_failure():
    """Test connection failure."""
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, servicer, channel, server = create_learner_service_stub()
    servicer.should_fail = True
    shutdown_event = Event()

    # Test failed connection; patch sleep to skip the retry back-off
    with patch("time.sleep"):
        result = establish_learner_connection(stub, shutdown_event, attempts=2)
    assert result is False

    close_service_stub(channel, server)
@require_package("grpc")
def test_push_transitions_to_transport_queue():
    """Test pushing transitions to transport queue."""
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
    from tests.transport.test_transport_utils import assert_transitions_equal

    # Create mock transitions
    transitions = []
    for i in range(3):
        transition = Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + i),
            done=torch.tensor(False),
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(i)},
        )
        transitions.append(transition)

    transitions_queue = Queue()

    # Test pushing transitions
    push_transitions_to_transport_queue(transitions, transitions_queue)

    # Verify the pushed data is serialized bytes and round-trips losslessly
    serialized_data = transitions_queue.get()
    assert isinstance(serialized_data, bytes)

    deserialized_transitions = bytes_to_transitions(serialized_data)
    assert len(deserialized_transitions) == len(transitions)
    for i, deserialized_transition in enumerate(deserialized_transitions):
        assert_transitions_equal(deserialized_transition, transitions[i])
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_transitions_stream():
    """Test transitions stream functionality."""
    from lerobot.scripts.rl.actor import transitions_stream

    shutdown_event = Event()
    transitions_queue = Queue()

    # Add test data to queue
    test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"]
    for data in test_data:
        transitions_queue.put(data)

    # Collect streamed data
    streamed_data = []
    stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1)

    # Consume exactly len(test_data) items, then shut the stream down
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_data) - 1:
            shutdown_event.set()
            break

    # Verify we got messages in FIFO order
    assert len(streamed_data) == len(test_data)
    assert streamed_data[0].data == b"transition_data_1"
    assert streamed_data[1].data == b"transition_data_2"
    assert streamed_data[2].data == b"transition_data_3"
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_interactions_stream():
    """Test interactions stream functionality."""
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import interactions_stream

    shutdown_event = Event()
    interactions_queue = Queue()

    # Create test interaction data (similar structure to what would be sent)
    test_interactions = [
        {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2},
        {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7},
        {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1},
    ]

    # Serialize and enqueue the interaction data as it would be in practice.
    # (A plain loop: the original used a comprehension purely for its side
    # effect, leaving a list of Nones behind.)
    for interaction in test_interactions:
        interactions_queue.put(python_object_to_bytes(interaction))

    # Collect streamed data
    streamed_data = []
    stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1)

    # Consume exactly len(test_interactions) items, then shut the stream down
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_interactions) - 1:
            shutdown_event.set()
            break

    # Verify we got messages
    assert len(streamed_data) == len(test_interactions)

    # Verify the messages can be deserialized back to original data
    for i, message in enumerate(streamed_data):
        deserialized_interaction = bytes_to_python_object(message.data)
        assert deserialized_interaction == test_interactions[i]

View File

@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import threading
import time
import pytest
import torch
from torch.multiprocessing import Event, Queue
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.utils.transition import Transition
from lerobot.configs.train import TrainRLServerPipelineConfig
from tests.utils import require_package
def create_test_transitions(count: int = 3) -> list[Transition]:
    """Create test transitions for integration testing.

    The final transition is flagged as ``done`` to mark an episode end.
    """
    return [
        Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + idx),
            done=torch.tensor(idx == count - 1),  # last transition ends the episode
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(idx), "episode_id": idx // 2},
        )
        for idx in range(count)
    ]
def create_test_interactions(count: int = 3) -> list[dict]:
    """Create test interactions for integration testing.

    Each dict mimics the metrics an actor would report; the values vary
    linearly with the index so entries are distinguishable.
    """
    return [
        {
            "episode_reward": 10.0 + 5 * idx,
            "step": 100 * idx,
            "policy_fps": 30.0 + idx,
            "intervention_rate": 0.1 * idx,
            "episode_length": 200 + 50 * idx,
        }
        for idx in range(count)
    ]
def find_free_port():
    """Return a TCP port number that is currently free on this machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))  # port 0 asks the OS to pick any free port
        sock.listen(1)
        return sock.getsockname()[1]
@pytest.fixture
def cfg():
    """Build a training pipeline config wired to a free local port, using threads."""
    policy_cfg = SACConfig()
    policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
    policy_cfg.actor_learner_config.learner_port = find_free_port()
    policy_cfg.actor_learner_config.queue_get_timeout = 0.1
    policy_cfg.concurrency.actor = "threads"
    policy_cfg.concurrency.learner = "threads"

    pipeline_cfg = TrainRLServerPipelineConfig()
    pipeline_cfg.policy = policy_cfg
    return pipeline_cfg
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
    """Test complete transitions flow from actor to learner."""
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        push_transitions_to_transport_queue,
        send_transitions,
    )
    from lerobot.scripts.rl.learner import start_learner
    from tests.transport.test_transport_utils import assert_transitions_equal

    transitions_actor_queue = Queue()
    transitions_learner_queue = Queue()
    interactions_queue = Queue()
    parameters_queue = Queue()
    shutdown_event = Event()

    # Run the learner in the background so the actor can connect to it.
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    send_transitions_thread = threading.Thread(
        target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
    )
    send_transitions_thread.start()

    input_transitions = create_test_transitions(count=5)
    push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)

    # Give the background threads time to transfer the data
    time.sleep(0.1)

    shutdown_event.set()

    # Wait for learner to receive transitions
    learner_thread.join()
    send_transitions_thread.join()
    channel.close()

    received_transitions = []
    while not transitions_learner_queue.empty():
        received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))

    assert len(received_transitions) == len(input_transitions)
    for i, transition in enumerate(received_transitions):
        assert_transitions_equal(transition, input_transitions[i])
@require_package("grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
    """Test complete interactions flow from actor to learner."""
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        send_interactions,
    )
    from lerobot.scripts.rl.learner import start_learner

    # Queues for actor-learner communication
    interactions_actor_queue = Queue()
    interactions_learner_queue = Queue()

    # Other queues required by the learner
    parameters_queue = Queue()
    transitions_learner_queue = Queue()
    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )
    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's interaction sending process in a separate thread
    send_interactions_thread = threading.Thread(
        target=send_interactions,
        args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
    )
    send_interactions_thread.start()

    # Create and push test interactions to the actor's queue
    input_interactions = create_test_interactions(count=5)
    for interaction in input_interactions:
        interactions_actor_queue.put(python_object_to_bytes(interaction))

    # Wait for the communication to happen
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    send_interactions_thread.join()
    channel.close()

    # Verify that the learner received the interactions
    received_interactions = []
    while not interactions_learner_queue.empty():
        received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))

    assert len(received_interactions) == len(input_interactions)

    # Sort by a unique key to handle potential reordering in queues
    received_interactions.sort(key=lambda x: x["step"])
    input_interactions.sort(key=lambda x: x["step"])

    for received, expected in zip(received_interactions, input_interactions, strict=False):
        assert received == expected
@require_package("grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
    """Test complete parameter flow from learner to actor, with small and large data."""
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
    from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
    from lerobot.scripts.rl.learner import start_learner

    # Actor's local queue to receive params
    parameters_actor_queue = Queue()
    # Learner's queue to send params from
    parameters_learner_queue = Queue()

    # Other queues required by the learner
    transitions_learner_queue = Queue()
    interactions_learner_queue = Queue()
    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(
            parameters_learner_queue,
            transitions_learner_queue,
            interactions_learner_queue,
            shutdown_event,
            cfg,
        ),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )
    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's parameter receiving process in a separate thread
    receive_params_thread = threading.Thread(
        target=receive_policy,
        args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
    )
    receive_params_thread.start()

    # Create test parameters based on parametrization
    if data_size == "small":
        input_params = {"layer.weight": torch.randn(128, 64)}
    else:  # "large"
        # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking
        input_params = {"large_layer.weight": torch.randn(1024, 1024)}

    # Simulate learner having new parameters to send
    parameters_learner_queue.put(state_to_bytes(input_params))

    # Wait for the actor to receive the parameters
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    receive_params_thread.join()
    channel.close()

    # Verify that the actor received the parameters correctly
    received_params = bytes_to_state_dict(parameters_actor_queue.get())
    assert received_params.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(received_params[key], input_params[key])

View File

@@ -0,0 +1,374 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from concurrent import futures
from multiprocessing import Event, Queue
import pytest
from tests.utils import require_package # our gRPC servicer class
@pytest.fixture(scope="function")
def learner_service_stub():
    """Yield a LearnerService stub backed by a live in-process gRPC server."""
    push_period_s = 1
    stub, channel, server = create_learner_service_stub(
        Event(), Queue(), Queue(), Queue(), push_period_s
    )
    yield stub  # provide the stub to the test function
    close_learner_service_stub(channel, server)
@require_package("grpc")
def create_learner_service_stub(
    shutdown_event: Event,
    parameters_queue: Queue,
    transitions_queue: Queue,
    interactions_queue: Queue,
    seconds_between_pushes: int,
    queue_get_timeout: float = 0.1,
):
    """Start a LearnerService gRPC server and return a connected stub.

    Returns:
        A tuple ``(stub, channel, server)``; callers must clean up with
        ``close_learner_service_stub(channel, server)``.
    """
    import grpc

    from lerobot.common.transport import services_pb2_grpc  # generated from .proto
    from lerobot.scripts.rl.learner_service import LearnerService

    servicer = LearnerService(
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        seconds_between_pushes=seconds_between_pushes,
        transition_queue=transitions_queue,
        interaction_message_queue=interactions_queue,
        queue_get_timeout=queue_get_timeout,
    )

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by the OS
    server.start()  # non-blocking call

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), channel, server
@require_package("grpc")
def close_learner_service_stub(channel, server):
    """Close the client channel, then stop the gRPC server without a grace period."""
    channel.close()
    server.stop(grace=None)
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_ready_method(learner_service_stub):
    """Test the Ready method of the LearnerService."""
    from lerobot.common.transport import services_pb2

    request = services_pb2.Empty()
    response = learner_service_stub.Ready(request)
    assert response == services_pb2.Empty()
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_interactions():
    """Test that chunked interaction messages are reassembled into the queue."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # BEGIN/MIDDLE/END triplets should be concatenated; lone END messages
    # should pass through as-is.
    list_of_interaction_messages = [
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
    ]

    def mock_interactions_stream():
        yield from list_of_interaction_messages

    response = client.SendInteractions(mock_interactions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Extract the data from the interactions queue
    interactions = []
    while not interactions_queue.empty():
        interactions.append(interactions_queue.get())

    assert interactions == [b"123", b"4", b"5", b"678"]
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions():
    """Test the SendTransitions method with various transition data."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Create test transition messages
    list_of_transition_messages = [
        services_pb2.Transition(
            transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1"
        ),
        services_pb2.Transition(
            transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2"
        ),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"),
    ]

    def mock_transitions_stream():
        yield from list_of_transition_messages

    response = client.SendTransitions(mock_transitions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Extract the data from the transitions queue
    transitions = []
    while not transitions_queue.empty():
        transitions.append(transitions_queue.get())

    # Should have assembled the chunked data
    assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions_empty_stream():
    """Test SendTransitions with empty stream."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    def empty_stream():
        return iter([])

    response = client.SendTransitions(empty_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Queue should remain empty
    assert transitions_queue.empty()
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_stream_parameters():
    """Test the StreamParameters method."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.2  # Short delay for testing

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Add test parameters to the queue
    test_params = [b"param_batch_1", b"param_batch_2"]
    for param in test_params:
        parameters_queue.put(param)

    # Start streaming parameters
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters and timestamps
    received_params = []
    timestamps = []
    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())
        # We should receive one last item
        break

    parameters_queue.put(b"param_batch_3")
    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())
        # We should receive only one item
        break

    shutdown_event.set()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_batch_2", b"param_batch_3"]

    # Check the time difference between the two sends
    time_diff = timestamps[1] - timestamps[0]

    # Check if the time difference is close to the expected push frequency
    assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
    """Test StreamParameters handles shutdown gracefully."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.1
    queue_get_timeout = 0.001

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"]

    # Feed parameters from a background thread so the stream has data to pull
    def producer():
        for param in test_params:
            parameters_queue.put(param)
            time.sleep(0.1)

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # Start streaming
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters until we hit the sentinel, then shut down;
    # the stream should terminate on its own after shutdown_event is set.
    received_params = []
    for response in stream:
        received_params.append(response.data)
        if response.data == b"stop":
            shutdown_event.set()

    producer_thread.join()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_batch_1", b"stop"]
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
    """Test that StreamParameters waits and retries when the queue is empty."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.05
    queue_get_timeout = 0.01

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    received_params = []

    def producer():
        # Let the consumer start and find an empty queue.
        # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout
        # after `queue_get_timeout` (0.01s). Total time for the first empty loop
        # is > 0.06s. We wait a bit longer to be safe.
        time.sleep(0.06)
        parameters_queue.put(b"param_after_wait")
        time.sleep(0.05)
        parameters_queue.put(b"param_after_wait_2")

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # The consumer will block here until the producer sends an item.
    for response in stream:
        received_params.append(response.data)
        if response.data == b"param_after_wait_2":
            break  # We only need one item for this test.

    shutdown_event.set()
    producer_thread.join()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_after_wait", b"param_after_wait_2"]