[HIL-SERL] Migrate threading to multiprocessing (#759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eugene Mironov
2025-03-05 17:19:31 +07:00
committed by AdilZouitine
parent 38f5fa4523
commit db78fee9de
14 changed files with 900 additions and 492 deletions

View File

@@ -1,23 +1,13 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import torch
from torch import nn
from threading import Lock, Event
import logging
import queue
import io
import pickle
from lerobot.scripts.server.buffer import (
move_state_dict_to_device,
bytes_buffer_size,
state_to_bytes,
)
from multiprocessing import Event, Queue
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
MAX_WORKERS = 10
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
STUTDOWN_TIMEOUT = 10
@@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event,
policy: nn.Module,
policy_lock: Lock,
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: queue.Queue,
interaction_message_queue: queue.Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
):
self.shutdown_event = shutdown_event
self.policy = policy
self.policy_lock = policy_lock
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def _get_policy_state(self):
with self.policy_lock:
params_dict = self.policy.actor.state_dict()
# if self.policy.config.vision_encoder_name is not None:
# if self.policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v
# for k, v in params_dict.items()
# if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
return move_state_dict_to_device(params_dict, device="cpu")
def _send_bytes(self, buffer: bytes):
size_in_bytes = bytes_buffer_size(buffer)
sent_bytes = 0
logging.info(f"Model state size {size_in_bytes/1024/1024} MB with")
while sent_bytes < size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
elif sent_bytes == 0:
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
chunk = buffer.read(size_to_read)
yield hilserl_pb2.Parameters(
transfer_state=transfer_state, parameter_bytes=chunk
)
sent_bytes += size_to_read
logging.info(
f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
def StreamParameters(self, request, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.debug("[LEARNER] Push parameters to the Actor")
state_dict = self._get_policy_state()
logging.info("[LEARNER] Push parameters to the Actor")
buffer = self.parameters_queue.get()
with state_to_bytes(state_dict) as buffer:
yield from self._send_bytes(buffer)
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
logging.info("[LEARNER] Parameters sent")
self.shutdown_event.wait(self.seconds_between_pushes)
def ReceiveTransitions(self, request_iterator, context):
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
def SendTransitions(self, request_iterator, _context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
for request in request_iterator:
logging.debug("[LEARNER] Received request")
if request.HasField("transition"):
buffer = io.BytesIO(request.transition.transition_bytes)
transition = torch.load(buffer)
self.transition_queue.put(transition)
if request.HasField("interaction_message"):
content = pickle.loads(
request.interaction_message.interaction_message_bytes
)
self.interaction_message_queue.put(content)
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
def SendInteractions(self, request_iterator, _context):
    """gRPC handler: receive a stream of interaction messages from the Actor.

    Drains `request_iterator` via `receive_bytes_in_chunks`, which
    (presumably — defined in network_utils, not visible here) reassembles
    chunked payloads and pushes each complete message onto
    `self.interaction_message_queue`; it is expected to stop when the
    stream ends or `self.shutdown_event` is set — confirm against
    network_utils.

    Returns:
        hilserl_pb2.Empty: empty ack once the stream is exhausted.
    """
    # TODO: authorize the request
    logging.info(
        "[LEARNER] Received request to receive interactions from the Actor"
    )
    # NOTE(review): queue is a multiprocessing.Queue (see module imports),
    # so messages cross a process boundary — payloads must stay picklable.
    receive_bytes_in_chunks(
        request_iterator,
        self.interaction_message_queue,
        self.shutdown_event,
        log_prefix="[LEARNER] interactions",
    )
    logging.debug("[LEARNER] Finished receiving interactions")
    return hilserl_pb2.Empty()
def Ready(self, request, context):
    """Respond with an empty message.

    NOTE(review): appears to serve as a liveness/readiness probe for the
    Actor to poll before streaming — confirm against the actor-side client.
    """
    return hilserl_pb2.Empty()