[HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
committed by
AdilZouitine
parent
38f5fa4523
commit
db78fee9de
@@ -1,23 +1,13 @@
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import torch
|
||||
from torch import nn
|
||||
from threading import Lock, Event
|
||||
import logging
|
||||
import queue
|
||||
import io
|
||||
import pickle
|
||||
|
||||
from lerobot.scripts.server.buffer import (
|
||||
move_state_dict_to_device,
|
||||
bytes_buffer_size,
|
||||
state_to_bytes,
|
||||
)
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
|
||||
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
|
||||
|
||||
# Size limits, in bytes.
MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB

# A single effective value: an earlier `MAX_WORKERS = 10` assignment was
# immediately overwritten by this one and has been removed as a dead store.
MAX_WORKERS = 3  # Stream parameters, send transitions and interactions

# NOTE(review): "STUTDOWN" is a misspelling of "SHUTDOWN". The name is kept
# unchanged because code outside this view may import it. The unit is
# presumably seconds -- confirm against the caller.
STUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
@@ -25,89 +15,68 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(
    self,
    shutdown_event: Event,
    policy: nn.Module,
    policy_lock: Lock,
    parameters_queue: Queue,
    seconds_between_pushes: float,
    transition_queue: Queue,
    interaction_message_queue: Queue,
):
    """Store the shared handles that the service methods read.

    Each parameter is declared exactly once; a parameter list that names
    ``transition_queue`` or ``interaction_message_queue`` twice is rejected
    by Python as a duplicate argument.

    Args:
        shutdown_event: Polled by ``StreamParameters`` to decide when to
            stop, and handed to ``receive_bytes_in_chunks``.
        policy: Module whose ``actor.state_dict()`` is read by
            ``_get_policy_state``.
        policy_lock: Held by ``_get_policy_state`` while the state dict is
            read.
        parameters_queue: Queue from which ``StreamParameters`` takes the
            payload to stream.
        seconds_between_pushes: Timeout passed to ``shutdown_event.wait``
            between two pushes.
        transition_queue: Queue handed to ``receive_bytes_in_chunks`` by
            ``SendTransitions``.
        interaction_message_queue: Queue handed to
            ``receive_bytes_in_chunks`` by ``SendInteractions``.
    """
    # NOTE(review): callers are not visible here -- confirm that the
    # positional order matches how the service is constructed.
    self.shutdown_event = shutdown_event
    self.policy = policy
    self.policy_lock = policy_lock
    self.parameters_queue = parameters_queue
    self.seconds_between_pushes = seconds_between_pushes
    self.transition_queue = transition_queue
    self.interaction_message_queue = interaction_message_queue
|
||||
|
||||
def _get_policy_state(self):
    """Return the actor's state dict after moving it to the CPU.

    ``self.policy_lock`` is held while ``state_dict()`` is read from
    ``self.policy.actor``; the result is then passed to
    ``move_state_dict_to_device`` with ``device="cpu"``.
    """
    # A disabled (commented-out) variant of this method dropped every key
    # starting with "encoder." when the vision encoder was frozen and raised
    # NotImplementedError otherwise. It is not active: the full actor state
    # dict is always returned.
    lock = self.policy_lock
    with lock:
        actor_params = self.policy.actor.state_dict()

    # NOTE(review): indentation is not preserved in this view; the device
    # move is assumed to run after the lock has been released -- confirm.
    return move_state_dict_to_device(actor_params, device="cpu")
|
||||
|
||||
def _send_bytes(self, buffer: io.BufferedIOBase):
    """Yield ``hilserl_pb2.Parameters`` messages covering ``buffer``.

    The payload is read from ``buffer`` in pieces of at most ``CHUNK_SIZE``
    bytes. Each message carries a transfer state:

    * ``TRANSFER_END`` for the chunk that reaches the end of the payload,
    * ``TRANSFER_BEGIN`` for the first chunk when more chunks follow,
    * ``TRANSFER_MIDDLE`` otherwise.

    A payload that fits in a single chunk therefore produces exactly one
    ``TRANSFER_END`` message and no ``TRANSFER_BEGIN``. Nothing is yielded
    when ``bytes_buffer_size`` reports a size of 0.

    Args:
        buffer: Readable binary stream; ``.read(n)`` is called on it, so a
            plain ``bytes`` object is not accepted.

    Yields:
        ``hilserl_pb2.Parameters`` messages, in payload order.
    """
    size_in_bytes = bytes_buffer_size(buffer)
    sent_bytes = 0

    logging.info(f"Model state size {size_in_bytes/1024/1024} MB")

    while sent_bytes < size_in_bytes:
        size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)

        # The END check comes first so that a single-chunk payload is
        # tagged END rather than BEGIN.
        if sent_bytes + CHUNK_SIZE >= size_in_bytes:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_END
        elif sent_bytes == 0:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
        else:
            transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE

        chunk = buffer.read(size_to_read)

        yield hilserl_pb2.Parameters(
            transfer_state=transfer_state, parameter_bytes=chunk
        )
        sent_bytes += size_to_read
        logging.info(
            f"[Learner] Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
        )

    logging.info(f"[LEARNER] Published {sent_bytes/1024/1024} MB to the Actor")
|
||||
|
||||
def StreamParameters(self, request, context):
    """Stream payloads taken from ``self.parameters_queue`` to the caller.

    Until ``self.shutdown_event`` is set, each iteration takes one item from
    ``self.parameters_queue``, forwards it through ``send_bytes_in_chunks``
    as ``hilserl_pb2.Parameters`` messages, and then waits on the shutdown
    event for up to ``self.seconds_between_pushes``.

    Each payload is sent exactly once: there is a single source (the queue)
    and a single sender, rather than a policy snapshot and a queue item both
    bound to ``buffer`` and both streamed in the same iteration.

    Args:
        request: Incoming request message; not read.
        context: RPC context; not read.

    Yields:
        Whatever ``send_bytes_in_chunks`` yields for each payload.
    """
    # TODO: authorize the request
    logging.info("[LEARNER] Received request to stream parameters from the Actor")

    while not self.shutdown_event.is_set():
        logging.info("[LEARNER] Push parameters to the Actor")
        # NOTE(review): get() is called without a timeout, so a shutdown
        # request is only noticed once another item arrives -- confirm this
        # is acceptable.
        buffer = self.parameters_queue.get()

        yield from send_bytes_in_chunks(
            buffer,
            hilserl_pb2.Parameters,
            log_prefix="[LEARNER] Sending parameters",
            silent=True,
        )

        logging.info("[LEARNER] Parameters sent")

        # Doubles as the pacing delay and as an early wake-up on shutdown.
        self.shutdown_event.wait(self.seconds_between_pushes)

    logging.info("[LEARNER] Stream parameters finished")
    return hilserl_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context):
    """Receive a stream of transition chunks from the caller.

    ``request_iterator`` is handed, untouched, to
    ``receive_bytes_in_chunks`` together with ``self.transition_queue`` and
    ``self.shutdown_event``. The iterator is consumed exactly once: a
    per-request decode loop placed before this call would exhaust it and
    leave nothing for the chunked receiver. No ``torch.load`` or
    ``pickle.loads`` is run on the incoming bytes in this method.

    Args:
        request_iterator: Iterator of incoming request messages.
        _context: RPC context; not read.

    Returns:
        An empty ``hilserl_pb2.Empty`` message.
    """
    # TODO: authorize the request
    logging.info("[LEARNER] Received request to receive transitions from the Actor")

    receive_bytes_in_chunks(
        request_iterator,
        self.transition_queue,
        self.shutdown_event,
        log_prefix="[LEARNER] transitions",
    )

    logging.debug("[LEARNER] Finished receiving transitions")
    return hilserl_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context):
    """Receive a stream of interaction-message chunks from the caller.

    ``request_iterator`` is passed to ``receive_bytes_in_chunks`` together
    with ``self.interaction_message_queue`` and ``self.shutdown_event``.

    Args:
        request_iterator: Iterator of incoming request messages.
        _context: RPC context; not read.

    Returns:
        An empty ``hilserl_pb2.Empty`` message.
    """
    # TODO: authorize the request
    logging.info("[LEARNER] Received request to receive interactions from the Actor")

    destination = self.interaction_message_queue
    stop_signal = self.shutdown_event
    receive_bytes_in_chunks(
        request_iterator,
        destination,
        stop_signal,
        log_prefix="[LEARNER] interactions",
    )

    logging.debug("[LEARNER] Finished receiving interactions")
    reply = hilserl_pb2.Empty()
    return reply
|
||||
|
||||
def Ready(self, request, context):
    """Answer with an empty message; neither argument is read."""
    reply = hilserl_pb2.Empty()
    return reply
|
||||
|
||||
Reference in New Issue
Block a user