forked from tangger/lerobot
[Async Inference] Add gRPC retry mechanism to Async client (#1485)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -63,12 +63,12 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
@@ -399,8 +399,6 @@ def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
import json
|
||||
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
@@ -408,34 +406,9 @@ def learner_service_client(
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
||||
"maxBackoff": "2s", # Max wait time between retries
|
||||
"backoffMultiplier": 2, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.enable_retries", 1),
|
||||
("grpc.service_config", service_config_json),
|
||||
],
|
||||
grpc_channel_options(),
|
||||
)
|
||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
|
||||
@@ -77,6 +77,7 @@ from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
bytes_to_python_object,
|
||||
bytes_to_transitions,
|
||||
state_to_bytes,
|
||||
@@ -658,8 +659,8 @@ def start_learner(
|
||||
server = grpc.server(
|
||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||
options=[
|
||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
||||
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
|
||||
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ from lerobot.transport import (
|
||||
async_inference_pb2, # type: ignore
|
||||
async_inference_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options
|
||||
|
||||
|
||||
class RobotClient:
|
||||
@@ -113,7 +114,9 @@ class RobotClient:
|
||||
config.actions_per_chunk,
|
||||
config.policy_device,
|
||||
)
|
||||
self.channel = grpc.insecure_channel(self.server_address)
|
||||
self.channel = grpc.insecure_channel(
|
||||
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||
)
|
||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import pickle # nosec B403: Safe usage for internal serialization only
|
||||
from multiprocessing import Event, Queue
|
||||
@@ -27,6 +28,7 @@ from lerobot.transport import services_pb2
|
||||
from lerobot.utils.transition import Transition
|
||||
|
||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||
|
||||
|
||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||
@@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(transitions, buffer)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def grpc_channel_options(
|
||||
max_receive_message_length: int = MAX_MESSAGE_SIZE,
|
||||
max_send_message_length: int = MAX_MESSAGE_SIZE,
|
||||
enable_retries: bool = True,
|
||||
initial_backoff: str = "0.1s",
|
||||
max_attempts: int = 5,
|
||||
backoff_multiplier: float = 2,
|
||||
max_backoff: str = "2s",
|
||||
):
|
||||
service_config = {
|
||||
"methodConfig": [
|
||||
{
|
||||
"name": [{}], # Applies to ALL methods in ALL services
|
||||
"retryPolicy": {
|
||||
"maxAttempts": max_attempts, # Max retries (total attempts = 5)
|
||||
"initialBackoff": initial_backoff, # First retry after 0.1s
|
||||
"maxBackoff": max_backoff, # Max wait time between retries
|
||||
"backoffMultiplier": backoff_multiplier, # Exponential backoff factor
|
||||
"retryableStatusCodes": [
|
||||
"UNAVAILABLE",
|
||||
"DEADLINE_EXCEEDED",
|
||||
], # Retries on network failures
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
service_config_json = json.dumps(service_config)
|
||||
|
||||
retries_option = 1 if enable_retries else 0
|
||||
|
||||
return [
|
||||
("grpc.max_receive_message_length", max_receive_message_length),
|
||||
("grpc.max_send_message_length", max_send_message_length),
|
||||
("grpc.enable_retries", retries_option),
|
||||
("grpc.service_config", service_config_json),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user