[HIL-SERL] Migrate threading to multiprocessing (#759)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eugene Mironov
2025-03-05 17:19:31 +07:00
committed by Michel Aractingi
parent 85fe8a3f4e
commit b6a2200983
14 changed files with 900 additions and 492 deletions

View File

@@ -23,6 +23,7 @@ from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os
import pickle
class Transition(TypedDict):
@@ -91,7 +92,7 @@ def move_transition_to_device(
return transition
def move_state_dict_to_device(state_dict, device):
def move_state_dict_to_device(state_dict, device="cpu"):
"""
Recursively move all tensors in a (potentially) nested
dict/list/tuple structure to the CPU.
@@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer
return buffer.getvalue()
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return result
return torch.load(buffer)
def python_object_to_bytes(python_object: Any) -> bytes:
return pickle.dumps(python_object)
def bytes_to_python_object(buffer: bytes) -> Any:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return pickle.load(buffer)
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer)
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
buffer = io.BytesIO()
torch.save(transitions, buffer)
return buffer.getvalue()
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: