[Async Inference] Add gRPC retry mechanism to Async client (#1485)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Author: Eugene Mironov
Date: 2025-07-16 21:13:01 +07:00
Committed by: GitHub
parent dfb1571bcf
commit 816034948a
5 changed files with 50 additions and 33 deletions


@@ -63,12 +63,12 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.policies.factory import make_policy
 from lerobot.policies.sac.modeling_sac import SACPolicy
 from lerobot.robots import so100_follower  # noqa: F401
-from lerobot.scripts.rl import learner_service
 from lerobot.scripts.rl.gym_manipulator import make_robot_env
 from lerobot.teleoperators import gamepad, so101_leader  # noqa: F401
 from lerobot.transport import services_pb2, services_pb2_grpc
 from lerobot.transport.utils import (
     bytes_to_state_dict,
+    grpc_channel_options,
     python_object_to_bytes,
     receive_bytes_in_chunks,
     send_bytes_in_chunks,
@@ -399,8 +399,6 @@ def learner_service_client(
     host: str = "127.0.0.1",
     port: int = 50051,
 ) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
-    import json
-
     """
     Returns a client for the learner service.
@@ -408,34 +406,9 @@ def learner_service_client(
     So we need to create only one client and reuse it.
     """
-    service_config = {
-        "methodConfig": [
-            {
-                "name": [{}],  # Applies to ALL methods in ALL services
-                "retryPolicy": {
-                    "maxAttempts": 5,  # Max retries (total attempts = 5)
-                    "initialBackoff": "0.1s",  # First retry after 0.1s
-                    "maxBackoff": "2s",  # Max wait time between retries
-                    "backoffMultiplier": 2,  # Exponential backoff factor
-                    "retryableStatusCodes": [
-                        "UNAVAILABLE",
-                        "DEADLINE_EXCEEDED",
-                    ],  # Retries on network failures
-                },
-            }
-        ]
-    }
-    service_config_json = json.dumps(service_config)
     channel = grpc.insecure_channel(
         f"{host}:{port}",
-        options=[
-            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.enable_retries", 1),
-            ("grpc.service_config", service_config_json),
-        ],
+        grpc_channel_options(),
     )
     stub = services_pb2_grpc.LearnerServiceStub(channel)
     logging.info("[ACTOR] Learner service client created")
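
With this change the actor no longer assembles the retry service config inline; the channel is built from the shared helper instead. A minimal sketch of the resulting call, assuming lerobot is importable and using the grpc_channel_options helper added in the last hunk of this commit:

import grpc

from lerobot.transport.utils import grpc_channel_options

# Same default endpoint the actor targets. grpc_channel_options() reproduces
# the option list the removed block built by hand: the message-size limits,
# ("grpc.enable_retries", 1), and the JSON retry policy under
# "grpc.service_config".
channel = grpc.insecure_channel("127.0.0.1:50051", grpc_channel_options())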


@@ -77,6 +77,7 @@ from lerobot.scripts.rl import learner_service
 from lerobot.teleoperators import gamepad, so101_leader  # noqa: F401
 from lerobot.transport import services_pb2_grpc
 from lerobot.transport.utils import (
+    MAX_MESSAGE_SIZE,
     bytes_to_python_object,
     bytes_to_transitions,
     state_to_bytes,
@@ -658,8 +659,8 @@ def start_learner(
     server = grpc.server(
         ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
         options=[
-            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
-            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
         ],
     )
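
The server keeps only the message-size options; the retry policy is applied on the client channel, since gRPC retries are a caller-side concern. Moving MAX_MESSAGE_SIZE into transport.utils also keeps both endpoints on the same limit. A small sketch of that invariant (assumes lerobot is importable; dict() works because the helper returns a list of key/value tuples):

from lerobot.transport.utils import MAX_MESSAGE_SIZE, grpc_channel_options

client_options = dict(grpc_channel_options())
# Client and server now read the same constant, so bumping it in
# transport/utils.py updates both sides at once.
assert client_options["grpc.max_receive_message_length"] == MAX_MESSAGE_SIZE
assert client_options["grpc.max_send_message_length"] == MAX_MESSAGE_SIZE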


@@ -23,7 +23,6 @@ from lerobot.transport import services_pb2, services_pb2_grpc
 from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
 from lerobot.utils.queue import get_last_item_from_queue

-MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
 MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
 SHUTDOWN_TIMEOUT = 10


@@ -76,6 +76,7 @@ from lerobot.transport import (
     async_inference_pb2,  # type: ignore
     async_inference_pb2_grpc,  # type: ignore
 )
+from lerobot.transport.utils import grpc_channel_options


 class RobotClient:
@@ -113,7 +114,9 @@ class RobotClient:
             config.actions_per_chunk,
             config.policy_device,
         )

-        self.channel = grpc.insecure_channel(self.server_address)
+        self.channel = grpc.insecure_channel(
+            self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
+        )
         self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
         self.logger.info(f"Initializing client to connect to server at {self.server_address}")
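
The client seeds the retry policy's first backoff with its own control period, so a dropped call is retried roughly one environment step later instead of after the generic 0.1 s default. A worked example of the formatting (environment_dt is assumed to be the control-loop period in seconds):

environment_dt = 1 / 30  # e.g. a 30 Hz control loop

initial_backoff = f"{environment_dt:.4f}s"
print(initial_backoff)  # -> "0.0333s": first retry lands about one step later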


@@ -16,6 +16,7 @@
 # limitations under the License.

 import io
+import json
 import logging
 import pickle  # nosec B403: Safe usage for internal serialization only
 from multiprocessing import Event, Queue
@@ -27,6 +28,7 @@ from lerobot.transport import services_pb2
 from lerobot.utils.transition import Transition

 CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
+MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
def bytes_buffer_size(buffer: io.BytesIO) -> int:
@@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes:
     buffer = io.BytesIO()
     torch.save(transitions, buffer)
     return buffer.getvalue()
+
+
+def grpc_channel_options(
+    max_receive_message_length: int = MAX_MESSAGE_SIZE,
+    max_send_message_length: int = MAX_MESSAGE_SIZE,
+    enable_retries: bool = True,
+    initial_backoff: str = "0.1s",
+    max_attempts: int = 5,
+    backoff_multiplier: float = 2,
+    max_backoff: str = "2s",
+):
+    service_config = {
+        "methodConfig": [
+            {
+                "name": [{}],  # Applies to ALL methods in ALL services
+                "retryPolicy": {
+                    "maxAttempts": max_attempts,  # Total attempts, including the original call
+                    "initialBackoff": initial_backoff,  # Delay before the first retry
+                    "maxBackoff": max_backoff,  # Max wait time between retries
+                    "backoffMultiplier": backoff_multiplier,  # Exponential backoff factor
+                    "retryableStatusCodes": [
+                        "UNAVAILABLE",
+                        "DEADLINE_EXCEEDED",
+                    ],  # Retries on transient network failures
+                },
+            }
+        ]
+    }
+    service_config_json = json.dumps(service_config)
+    retries_option = 1 if enable_retries else 0
+    return [
+        ("grpc.max_receive_message_length", max_receive_message_length),
+        ("grpc.max_send_message_length", max_send_message_length),
+        ("grpc.enable_retries", retries_option),
+        ("grpc.service_config", service_config_json),
+    ]
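
Under this policy the n-th retry's backoff bound is min(initial_backoff * backoff_multiplier**(n-1), max_backoff); per gRPC's retry design (gRFC A6), the actual sleep is drawn uniformly from [0, bound]. A small worked example of those bounds under the defaults (hypothetical helper, not part of the commit):

def retry_backoff_bounds(
    initial: float = 0.1, multiplier: float = 2.0, cap: float = 2.0, attempts: int = 5
) -> list[float]:
    # maxAttempts counts the original call, so there are attempts - 1 retries
    return [min(initial * multiplier**n, cap) for n in range(attempts - 1)]

print(retry_backoff_bounds())  # -> [0.1, 0.2, 0.4, 0.8]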