forked from tangger/lerobot
[Async Inference] Add gRPC retry mechanism to Async client (#1485)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -63,12 +63,12 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
|||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
from lerobot.robots import so100_follower # noqa: F401
|
from lerobot.robots import so100_follower # noqa: F401
|
||||||
from lerobot.scripts.rl import learner_service
|
|
||||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||||
from lerobot.transport.utils import (
|
from lerobot.transport.utils import (
|
||||||
bytes_to_state_dict,
|
bytes_to_state_dict,
|
||||||
|
grpc_channel_options,
|
||||||
python_object_to_bytes,
|
python_object_to_bytes,
|
||||||
receive_bytes_in_chunks,
|
receive_bytes_in_chunks,
|
||||||
send_bytes_in_chunks,
|
send_bytes_in_chunks,
|
||||||
@@ -399,8 +399,6 @@ def learner_service_client(
|
|||||||
host: str = "127.0.0.1",
|
host: str = "127.0.0.1",
|
||||||
port: int = 50051,
|
port: int = 50051,
|
||||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||||
import json
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Returns a client for the learner service.
|
Returns a client for the learner service.
|
||||||
|
|
||||||
@@ -408,34 +406,9 @@ def learner_service_client(
|
|||||||
So we need to create only one client and reuse it.
|
So we need to create only one client and reuse it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
service_config = {
|
|
||||||
"methodConfig": [
|
|
||||||
{
|
|
||||||
"name": [{}], # Applies to ALL methods in ALL services
|
|
||||||
"retryPolicy": {
|
|
||||||
"maxAttempts": 5, # Max retries (total attempts = 5)
|
|
||||||
"initialBackoff": "0.1s", # First retry after 0.1s
|
|
||||||
"maxBackoff": "2s", # Max wait time between retries
|
|
||||||
"backoffMultiplier": 2, # Exponential backoff factor
|
|
||||||
"retryableStatusCodes": [
|
|
||||||
"UNAVAILABLE",
|
|
||||||
"DEADLINE_EXCEEDED",
|
|
||||||
], # Retries on network failures
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
service_config_json = json.dumps(service_config)
|
|
||||||
|
|
||||||
channel = grpc.insecure_channel(
|
channel = grpc.insecure_channel(
|
||||||
f"{host}:{port}",
|
f"{host}:{port}",
|
||||||
options=[
|
grpc_channel_options(),
|
||||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
|
||||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
|
||||||
("grpc.enable_retries", 1),
|
|
||||||
("grpc.service_config", service_config_json),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||||
logging.info("[ACTOR] Learner service client created")
|
logging.info("[ACTOR] Learner service client created")
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ from lerobot.scripts.rl import learner_service
|
|||||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||||
from lerobot.transport import services_pb2_grpc
|
from lerobot.transport import services_pb2_grpc
|
||||||
from lerobot.transport.utils import (
|
from lerobot.transport.utils import (
|
||||||
|
MAX_MESSAGE_SIZE,
|
||||||
bytes_to_python_object,
|
bytes_to_python_object,
|
||||||
bytes_to_transitions,
|
bytes_to_transitions,
|
||||||
state_to_bytes,
|
state_to_bytes,
|
||||||
@@ -658,8 +659,8 @@ def start_learner(
|
|||||||
server = grpc.server(
|
server = grpc.server(
|
||||||
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
|
||||||
options=[
|
options=[
|
||||||
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
|
("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
|
||||||
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
|
("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from lerobot.transport import services_pb2, services_pb2_grpc
|
|||||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||||
from lerobot.utils.queue import get_last_item_from_queue
|
from lerobot.utils.queue import get_last_item_from_queue
|
||||||
|
|
||||||
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
|
||||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||||
SHUTDOWN_TIMEOUT = 10
|
SHUTDOWN_TIMEOUT = 10
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ from lerobot.transport import (
|
|||||||
async_inference_pb2, # type: ignore
|
async_inference_pb2, # type: ignore
|
||||||
async_inference_pb2_grpc, # type: ignore
|
async_inference_pb2_grpc, # type: ignore
|
||||||
)
|
)
|
||||||
|
from lerobot.transport.utils import grpc_channel_options
|
||||||
|
|
||||||
|
|
||||||
class RobotClient:
|
class RobotClient:
|
||||||
@@ -113,7 +114,9 @@ class RobotClient:
|
|||||||
config.actions_per_chunk,
|
config.actions_per_chunk,
|
||||||
config.policy_device,
|
config.policy_device,
|
||||||
)
|
)
|
||||||
self.channel = grpc.insecure_channel(self.server_address)
|
self.channel = grpc.insecure_channel(
|
||||||
|
self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s")
|
||||||
|
)
|
||||||
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
|
||||||
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
self.logger.info(f"Initializing client to connect to server at {self.server_address}")
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle # nosec B403: Safe usage for internal serialization only
|
import pickle # nosec B403: Safe usage for internal serialization only
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
@@ -27,6 +28,7 @@ from lerobot.transport import services_pb2
|
|||||||
from lerobot.utils.transition import Transition
|
from lerobot.utils.transition import Transition
|
||||||
|
|
||||||
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
||||||
|
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
|
||||||
|
|
||||||
|
|
||||||
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
||||||
@@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
|||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
torch.save(transitions, buffer)
|
torch.save(transitions, buffer)
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def grpc_channel_options(
    max_receive_message_length: int = MAX_MESSAGE_SIZE,
    max_send_message_length: int = MAX_MESSAGE_SIZE,
    enable_retries: bool = True,
    initial_backoff: str = "0.1s",
    max_attempts: int = 5,
    backoff_multiplier: float = 2,
    max_backoff: str = "2s",
) -> list[tuple[str, int | str]]:
    """Build the shared option list for gRPC channels/servers used by LeRobot.

    The options cap message sizes and, when ``enable_retries`` is true, attach
    a service config that retries every method of every service on transient
    transport failures (``UNAVAILABLE`` / ``DEADLINE_EXCEEDED``) with
    exponential backoff.

    Args:
        max_receive_message_length: Maximum size in bytes of a received message.
        max_send_message_length: Maximum size in bytes of a sent message.
        enable_retries: Whether gRPC should honor the retry policy below.
        initial_backoff: Delay before the first retry, as a gRPC duration
            string (e.g. ``"0.1s"``).
        max_attempts: Total number of attempts, including the initial call.
        backoff_multiplier: Exponential growth factor applied between retries.
        max_backoff: Upper bound on the delay between retries, as a gRPC
            duration string.

    Returns:
        A list of ``(option_name, value)`` tuples suitable for the ``options=``
        argument of ``grpc.insecure_channel`` / ``grpc.server``.
    """
    service_config = {
        "methodConfig": [
            {
                # An empty name entry makes this policy apply to ALL methods
                # in ALL services on the channel.
                "name": [{}],
                "retryPolicy": {
                    "maxAttempts": max_attempts,
                    "initialBackoff": initial_backoff,
                    "maxBackoff": max_backoff,
                    "backoffMultiplier": backoff_multiplier,
                    # Retry only on transient, transport-level failures.
                    "retryableStatusCodes": [
                        "UNAVAILABLE",
                        "DEADLINE_EXCEEDED",
                    ],
                },
            }
        ]
    }

    return [
        ("grpc.max_receive_message_length", max_receive_message_length),
        ("grpc.max_send_message_length", max_send_message_length),
        # gRPC expects an integer flag here, not a bool.
        ("grpc.enable_retries", int(enable_retries)),
        # The service config must be passed as a JSON string.
        ("grpc.service_config", json.dumps(service_config)),
    ]
|
||||||
|
|||||||
Reference in New Issue
Block a user