Port HIL SERL (#644)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Ke Wang <superwk1017@gmail.com> Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
This commit is contained in:
208
tests/rl/test_actor.py
Normal file
208
tests/rl/test_actor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent import futures
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def create_learner_service_stub():
    """Start a mock LearnerService gRPC server and return a connected stub.

    Returns:
        Tuple of ``(stub, servicer, channel, server)``. The caller is
        responsible for tearing down the channel and server, e.g. via
        ``close_service_stub``.
    """
    import grpc

    from lerobot.common.transport import services_pb2, services_pb2_grpc

    class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
        """Minimal servicer that counts Ready() calls and can simulate failure."""

        def __init__(self):
            self.ready_call_count = 0
            self.should_fail = False

        def Ready(self, request, context):  # noqa: N802
            self.ready_call_count += 1
            if self.should_fail:
                context.set_code(grpc.StatusCode.UNAVAILABLE)
                context.set_details("Service unavailable")
                raise grpc.RpcError("Service unavailable")
            return services_pb2.Empty()

    servicer = MockLearnerService()

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by OS
    server.start()  # start the server (non-blocking call)

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server
|
||||
|
||||
|
||||
def close_service_stub(channel, server):
    """Tear down a client channel and its gRPC server (no grace period)."""
    channel.close()
    server.stop(None)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_establish_learner_connection_success():
    """Test successful connection establishment."""
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, _servicer, channel, server = create_learner_service_stub()

    shutdown_event = Event()

    # The mock servicer answers Ready(), so the handshake should succeed.
    result = establish_learner_connection(stub, shutdown_event, attempts=5)

    assert result is True

    close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_establish_learner_connection_failure():
    """Test connection failure."""
    from lerobot.scripts.rl.actor import establish_learner_connection

    stub, servicer, channel, server = create_learner_service_stub()
    servicer.should_fail = True  # make every Ready() call raise UNAVAILABLE

    shutdown_event = Event()

    # Test failed connection
    with patch("time.sleep"):  # Speed up the retry backoff
        result = establish_learner_connection(stub, shutdown_event, attempts=2)

    assert result is False

    close_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def test_push_transitions_to_transport_queue():
    """Test pushing transitions to transport queue."""
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
    from tests.transport.test_transport_utils import assert_transitions_equal

    # Create mock transitions
    transitions = []
    for i in range(3):
        transition = Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + i),
            done=torch.tensor(False),
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(i)},
        )
        transitions.append(transition)

    transitions_queue = Queue()

    # Test pushing transitions
    push_transitions_to_transport_queue(transitions, transitions_queue)

    # Verify the serialized payload round-trips to the original transitions.
    serialized_data = transitions_queue.get()
    assert isinstance(serialized_data, bytes)
    deserialized_transitions = bytes_to_transitions(serialized_data)
    assert len(deserialized_transitions) == len(transitions)
    for i, deserialized_transition in enumerate(deserialized_transitions):
        assert_transitions_equal(deserialized_transition, transitions[i])
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_transitions_stream():
    """Test transitions stream functionality."""
    from lerobot.scripts.rl.actor import transitions_stream

    shutdown_event = Event()
    transitions_queue = Queue()

    # Add test data to queue
    test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"]
    for data in test_data:
        transitions_queue.put(data)

    # Collect streamed data
    streamed_data = []
    stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1)

    # Consume exactly as many items as were queued, then stop the stream.
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_data) - 1:
            shutdown_event.set()
            break

    # Verify the messages arrived in order.
    assert len(streamed_data) == len(test_data)
    for message, expected in zip(streamed_data, test_data, strict=True):
        assert message.data == expected
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_interactions_stream():
    """Test interactions stream functionality."""
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import interactions_stream

    shutdown_event = Event()
    interactions_queue = Queue()

    # Create test interaction data (similar structure to what would be sent)
    test_interactions = [
        {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2},
        {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7},
        {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1},
    ]

    # Serialize and enqueue the interactions as the actor would.
    # (Plain loop, not a comprehension: only the put() side effect matters.)
    for interaction in test_interactions:
        interactions_queue.put(python_object_to_bytes(interaction))

    # Collect streamed data
    streamed_data = []
    stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1)

    # Process the items
    for i, message in enumerate(stream_generator):
        streamed_data.append(message)
        if i >= len(test_interactions) - 1:
            shutdown_event.set()
            break

    # Verify we got messages
    assert len(streamed_data) == len(test_interactions)

    # Verify the messages can be deserialized back to original data
    for i, message in enumerate(streamed_data):
        deserialized_interaction = bytes_to_python_object(message.data)
        assert deserialized_interaction == test_interactions[i]
|
||||
297
tests/rl/test_actor_learner.py
Normal file
297
tests/rl/test_actor_learner.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def create_test_transitions(count: int = 3) -> list[Transition]:
    """Build `count` synthetic transitions for integration testing.

    Only the final transition is marked as done; `episode_id` groups
    transitions in pairs.
    """
    return [
        Transition(
            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            action=torch.randn(5),
            reward=torch.tensor(1.0 + idx),
            done=torch.tensor(idx == count - 1),  # last transition terminates the episode
            truncated=torch.tensor(False),
            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
            complementary_info={"step": torch.tensor(idx), "episode_id": idx // 2},
        )
        for idx in range(count)
    ]
|
||||
|
||||
|
||||
def create_test_interactions(count: int = 3) -> list[dict]:
    """Build `count` synthetic interaction-metric dicts for integration testing."""
    return [
        {
            "episode_reward": 10.0 + idx * 5,
            "step": idx * 100,
            "policy_fps": 30.0 + idx,
            "intervention_rate": 0.1 * idx,
            "episode_length": 200 + idx * 50,
        }
        for idx in range(count)
    ]
|
||||
|
||||
|
||||
def find_free_port():
    """Return a TCP port number that is currently free on the local machine.

    Binds a throwaway socket to port 0 so the OS assigns an unused port,
    reads the assigned number back, then closes the socket.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("", 0))  # port 0: let the OS choose a free port
        sock.listen(1)
        return sock.getsockname()[1]
    finally:
        sock.close()
|
||||
|
||||
|
||||
@pytest.fixture
def cfg():
    """Train-RL pipeline config wired to a free local port, using thread concurrency."""
    pipeline_cfg = TrainRLServerPipelineConfig()

    port = find_free_port()

    policy_cfg = SACConfig()
    policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
    policy_cfg.actor_learner_config.learner_port = port
    policy_cfg.concurrency.actor = "threads"
    policy_cfg.concurrency.learner = "threads"
    policy_cfg.actor_learner_config.queue_get_timeout = 0.1

    pipeline_cfg.policy = policy_cfg

    return pipeline_cfg
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
    """Test complete transitions flow from actor to learner."""
    from lerobot.common.transport.utils import bytes_to_transitions
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        push_transitions_to_transport_queue,
        send_transitions,
    )
    from lerobot.scripts.rl.learner import start_learner
    from tests.transport.test_transport_utils import assert_transitions_equal

    transitions_actor_queue = Queue()
    transitions_learner_queue = Queue()

    interactions_queue = Queue()
    parameters_queue = Queue()
    shutdown_event = Event()

    # Run the learner in the background.
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Connect the actor-side client to the learner.
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    send_transitions_thread = threading.Thread(
        target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel)
    )
    send_transitions_thread.start()

    input_transitions = create_test_transitions(count=5)

    push_transitions_to_transport_queue(input_transitions, transitions_actor_queue)

    # Wait for learner to start
    time.sleep(0.1)

    shutdown_event.set()

    # Wait for learner to receive transitions
    learner_thread.join()
    send_transitions_thread.join()
    channel.close()

    received_transitions = []
    while not transitions_learner_queue.empty():
        received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get()))

    assert len(received_transitions) == len(input_transitions)
    for i, transition in enumerate(received_transitions):
        assert_transitions_equal(transition, input_transitions[i])
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
    """Test complete interactions flow from actor to learner."""
    from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
    from lerobot.scripts.rl.actor import (
        establish_learner_connection,
        learner_service_client,
        send_interactions,
    )
    from lerobot.scripts.rl.learner import start_learner

    # Queues for actor-learner communication
    interactions_actor_queue = Queue()
    interactions_learner_queue = Queue()

    # Other queues required by the learner
    parameters_queue = Queue()
    transitions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's interaction sending process in a separate thread
    send_interactions_thread = threading.Thread(
        target=send_interactions,
        args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel),
    )
    send_interactions_thread.start()

    # Create and push test interactions to the actor's queue
    input_interactions = create_test_interactions(count=5)
    for interaction in input_interactions:
        interactions_actor_queue.put(python_object_to_bytes(interaction))

    # Wait for the communication to happen
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    send_interactions_thread.join()
    channel.close()

    # Verify that the learner received the interactions
    received_interactions = []
    while not interactions_learner_queue.empty():
        received_interactions.append(bytes_to_python_object(interactions_learner_queue.get()))

    assert len(received_interactions) == len(input_interactions)

    # Sort by a unique key to handle potential reordering in queues
    received_interactions.sort(key=lambda x: x["step"])
    input_interactions.sort(key=lambda x: x["step"])

    for received, expected in zip(received_interactions, input_interactions, strict=False):
        assert received == expected
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
    """Test complete parameter flow from learner to actor, with small and large data."""
    from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
    from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
    from lerobot.scripts.rl.learner import start_learner

    # Actor's local queue to receive params
    parameters_actor_queue = Queue()
    # Learner's queue to send params from
    parameters_learner_queue = Queue()

    # Other queues required by the learner
    transitions_learner_queue = Queue()
    interactions_learner_queue = Queue()

    shutdown_event = Event()

    # Start the learner in a separate thread
    learner_thread = threading.Thread(
        target=start_learner,
        args=(
            parameters_learner_queue,
            transitions_learner_queue,
            interactions_learner_queue,
            shutdown_event,
            cfg,
        ),
    )
    learner_thread.start()

    # Establish connection from actor to learner
    policy_cfg = cfg.policy
    learner_client, channel = learner_service_client(
        host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port
    )

    assert establish_learner_connection(learner_client, shutdown_event, attempts=5)

    # Start the actor's parameter receiving process in a separate thread
    receive_params_thread = threading.Thread(
        target=receive_policy,
        args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel),
    )
    receive_params_thread.start()

    # Create test parameters based on parametrization
    if data_size == "small":
        input_params = {"layer.weight": torch.randn(128, 64)}
    else:  # "large"
        # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking
        input_params = {"large_layer.weight": torch.randn(1024, 1024)}

    # Simulate learner having new parameters to send
    parameters_learner_queue.put(state_to_bytes(input_params))

    # Wait for the actor to receive the parameters
    time.sleep(0.1)

    # Signal shutdown and wait for threads to complete
    shutdown_event.set()
    learner_thread.join()
    receive_params_thread.join()
    channel.close()

    # Verify that the actor received the parameters correctly
    received_params = bytes_to_state_dict(parameters_actor_queue.get())

    assert received_params.keys() == input_params.keys()
    for key in input_params:
        assert torch.allclose(received_params[key], input_params[key])
|
||||
374
tests/rl/test_learner_service.py
Normal file
374
tests/rl/test_learner_service.py
Normal file
@@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from concurrent import futures
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import require_package # our gRPC servicer class
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def learner_service_stub():
    """Yield a stub connected to a live LearnerService server; tear it down after the test."""
    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    yield client  # provide the stub to the test function

    close_learner_service_stub(channel, server)
|
||||
|
||||
|
||||
@require_package("grpc")
def create_learner_service_stub(
    shutdown_event: Event,
    parameters_queue: Queue,
    transitions_queue: Queue,
    interactions_queue: Queue,
    seconds_between_pushes: int,
    queue_get_timeout: float = 0.1,
):
    """Start a LearnerService gRPC server and return a connected stub.

    Returns:
        Tuple of ``(stub, channel, server)``. The caller must tear them
        down via ``close_learner_service_stub``.
    """
    import grpc

    from lerobot.common.transport import services_pb2_grpc  # generated from .proto
    from lerobot.scripts.rl.learner_service import LearnerService

    servicer = LearnerService(
        shutdown_event=shutdown_event,
        parameters_queue=parameters_queue,
        seconds_between_pushes=seconds_between_pushes,
        transition_queue=transitions_queue,
        interaction_message_queue=interactions_queue,
        queue_get_timeout=queue_get_timeout,
    )

    # Create a gRPC server and add our servicer to it.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server)
    port = server.add_insecure_port("[::]:0")  # bind to a free port chosen by OS
    server.start()  # start the server (non-blocking call)

    # Create a client channel and stub connected to the server's port.
    channel = grpc.insecure_channel(f"localhost:{port}")
    return services_pb2_grpc.LearnerServiceStub(channel), channel, server
|
||||
|
||||
|
||||
@require_package("grpc")
def close_learner_service_stub(channel, server):
    """Close the client channel, then stop the server with no grace period."""
    channel.close()
    server.stop(None)
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_ready_method(learner_service_stub):
    """Test the Ready method of the LearnerService."""
    from lerobot.common.transport import services_pb2

    request = services_pb2.Empty()
    response = learner_service_stub.Ready(request)
    assert response == services_pb2.Empty()
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_interactions():
    """Test that chunked interaction messages are reassembled into the queue."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()

    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1
    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # BEGIN/MIDDLE/END triplets form one payload; a lone END is complete by itself.
    list_of_interaction_messages = [
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"),
        services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"),
    ]

    def mock_interactions_stream():
        # NOTE: no `return <value>` here — a value returned from a generator is
        # swallowed as the StopIteration payload and never reaches gRPC.
        yield from list_of_interaction_messages

    response = client.SendInteractions(mock_interactions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Extract the data from the interactions queue
    interactions = []
    while not interactions_queue.empty():
        interactions.append(interactions_queue.get())

    assert interactions == [b"123", b"4", b"5", b"678"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions():
    """Test the SendTransitions method with various transition data."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Create test transition messages (two chunked payloads)
    list_of_transition_messages = [
        services_pb2.Transition(
            transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1"
        ),
        services_pb2.Transition(
            transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2"
        ),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"),
        services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"),
    ]

    def mock_transitions_stream():
        yield from list_of_transition_messages

    response = client.SendTransitions(mock_transitions_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Extract the data from the transitions queue
    transitions = []
    while not transitions_queue.empty():
        transitions.append(transitions_queue.get())

    # Should have assembled the chunked data
    assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_send_transitions_empty_stream():
    """Test SendTransitions with empty stream."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 1

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    def empty_stream():
        return iter([])

    response = client.SendTransitions(empty_stream())
    assert response == services_pb2.Empty()

    close_learner_service_stub(channel, server)

    # Queue should remain empty
    assert transitions_queue.empty()
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(10)  # force cross-platform watchdog
def test_stream_parameters():
    """Test the StreamParameters method."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.2  # Short delay for testing

    client, channel, server = create_learner_service_stub(
        shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes
    )

    # Add test parameters to the queue
    test_params = [b"param_batch_1", b"param_batch_2"]
    for param in test_params:
        parameters_queue.put(param)

    # Start streaming parameters
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters and timestamps
    received_params = []
    timestamps = []

    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())

        # We should receive one last item
        break

    parameters_queue.put(b"param_batch_3")

    for response in stream:
        received_params.append(response.data)
        timestamps.append(time.time())

        # We should receive only one item
        break

    shutdown_event.set()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_batch_2", b"param_batch_3"]

    # Check the time difference between the two sends
    time_diff = timestamps[1] - timestamps[0]
    # Check if the time difference is close to the expected push frequency
    assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
    """Test StreamParameters handles shutdown gracefully."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.1
    queue_get_timeout = 0.001

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"]

    # create a thread that will put the parameters in the queue
    def producer():
        for param in test_params:
            parameters_queue.put(param)
            time.sleep(0.1)  # space items out so shutdown lands between them

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # Start streaming
    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    # Collect streamed parameters; the stream should end after shutdown is set.
    received_params = []

    for response in stream:
        received_params.append(response.data)

        if response.data == b"stop":
            shutdown_event.set()

    producer_thread.join()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_batch_1", b"stop"]
|
||||
|
||||
|
||||
@require_package("grpc")
@pytest.mark.timeout(3)  # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
    """Test that StreamParameters waits and retries when the queue is empty."""
    from lerobot.common.transport import services_pb2

    shutdown_event = Event()
    parameters_queue = Queue()
    transitions_queue = Queue()
    interactions_queue = Queue()
    seconds_between_pushes = 0.05
    queue_get_timeout = 0.01

    client, channel, server = create_learner_service_stub(
        shutdown_event,
        parameters_queue,
        transitions_queue,
        interactions_queue,
        seconds_between_pushes,
        queue_get_timeout=queue_get_timeout,
    )

    request = services_pb2.Empty()
    stream = client.StreamParameters(request)

    received_params = []

    def producer():
        # Let the consumer start and find an empty queue.
        # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s).
        # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe.
        time.sleep(0.06)
        parameters_queue.put(b"param_after_wait")
        time.sleep(0.05)
        parameters_queue.put(b"param_after_wait_2")

    producer_thread = threading.Thread(target=producer)
    producer_thread.start()

    # The consumer will block here until the producer sends an item.
    for response in stream:
        received_params.append(response.data)
        if response.data == b"param_after_wait_2":
            break  # We only need one item for this test.

    shutdown_event.set()
    producer_thread.join()
    close_learner_service_stub(channel, server)

    assert received_params == [b"param_after_wait", b"param_after_wait_2"]
|
||||
Reference in New Issue
Block a user