forked from tangger/lerobot
96 lines
3.3 KiB
Python
96 lines
3.3 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team.
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import io
|
|
import logging
|
|
from multiprocessing import Event, Queue
|
|
from typing import Any
|
|
|
|
from lerobot.scripts.server import hilserl_pb2
|
|
|
|
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
|
|
|
|
|
|
def bytes_buffer_size(buffer: io.BytesIO) -> int:
|
|
buffer.seek(0, io.SEEK_END)
|
|
result = buffer.tell()
|
|
buffer.seek(0)
|
|
return result
|
|
|
|
|
|
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
|
|
buffer = io.BytesIO(buffer)
|
|
size_in_bytes = bytes_buffer_size(buffer)
|
|
|
|
sent_bytes = 0
|
|
|
|
logging_method = logging.info if not silent else logging.debug
|
|
|
|
logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with")
|
|
|
|
while sent_bytes < size_in_bytes:
|
|
transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
|
|
|
|
if sent_bytes + CHUNK_SIZE >= size_in_bytes:
|
|
transfer_state = hilserl_pb2.TransferState.TRANSFER_END
|
|
elif sent_bytes == 0:
|
|
transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
|
|
|
|
size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
|
|
chunk = buffer.read(size_to_read)
|
|
|
|
yield message_class(transfer_state=transfer_state, data=chunk)
|
|
sent_bytes += size_to_read
|
|
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
|
|
|
|
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
|
|
|
|
|
|
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
|
|
bytes_buffer = io.BytesIO()
|
|
step = 0
|
|
|
|
logging.info(f"{log_prefix} Starting receiver")
|
|
for item in iterator:
|
|
logging.debug(f"{log_prefix} Received item")
|
|
if shutdown_event.is_set():
|
|
logging.info(f"{log_prefix} Shutting down receiver")
|
|
return
|
|
|
|
if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
|
|
bytes_buffer.seek(0)
|
|
bytes_buffer.truncate(0)
|
|
bytes_buffer.write(item.data)
|
|
logging.debug(f"{log_prefix} Received data at step 0")
|
|
step = 0
|
|
continue
|
|
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
|
|
bytes_buffer.write(item.data)
|
|
step += 1
|
|
logging.debug(f"{log_prefix} Received data at step {step}")
|
|
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
|
|
bytes_buffer.write(item.data)
|
|
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
|
|
|
|
queue.put(bytes_buffer.getvalue())
|
|
|
|
bytes_buffer.seek(0)
|
|
bytes_buffer.truncate(0)
|
|
step = 0
|
|
|
|
logging.debug(f"{log_prefix} Queue updated")
|