[HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Committed by: Michel Aractingi
Parent: 85fe8a3f4e
Commit: b6a2200983
@@ -23,6 +23,7 @@ from tqdm import tqdm
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 import os
+import pickle
 
 
 class Transition(TypedDict):
@@ -91,7 +92,7 @@ def move_transition_to_device(
     return transition
 
 
-def move_state_dict_to_device(state_dict, device):
+def move_state_dict_to_device(state_dict, device="cpu"):
     """
     Recursively move all tensors in a (potentially) nested
     dict/list/tuple structure to the CPU.
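Only the new default value and the docstring appear in this hunk; the recursive body itself is unchanged and not shown. As an illustration of the traversal the docstring describes (a minimal sketch, not the commit's actual implementation):

    import torch

    def move_state_dict_to_device(state_dict, device="cpu"):
        # Tensors are moved directly; containers are rebuilt with their values moved.
        if isinstance(state_dict, torch.Tensor):
            return state_dict.to(device)
        if isinstance(state_dict, dict):
            return {key: move_state_dict_to_device(value, device) for key, value in state_dict.items()}
        if isinstance(state_dict, (list, tuple)):
            return type(state_dict)(move_state_dict_to_device(value, device) for value in state_dict)
        # Anything that is neither a tensor nor a container is returned untouched.
        return state_dict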
@@ -111,20 +112,41 @@ def move_state_dict_to_device(state_dict, device):
     return state_dict
 
 
-def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
+def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
     """Convert model state dict to flat array for transmission"""
     buffer = io.BytesIO()
 
     torch.save(state_dict, buffer)
 
-    return buffer
+    return buffer.getvalue()
 
 
-def bytes_buffer_size(buffer: io.BytesIO) -> int:
-    buffer.seek(0, io.SEEK_END)
-    result = buffer.tell()
-    return result
+def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
+    buffer = io.BytesIO(buffer)
+    buffer.seek(0)
+    return torch.load(buffer)
+
+
+def python_object_to_bytes(python_object: Any) -> bytes:
+    return pickle.dumps(python_object)
+
+
+def bytes_to_python_object(buffer: bytes) -> Any:
+    buffer = io.BytesIO(buffer)
+    buffer.seek(0)
+    return pickle.load(buffer)
+
+
+def bytes_to_transitions(buffer: bytes) -> list[Transition]:
+    buffer = io.BytesIO(buffer)
+    buffer.seek(0)
+    return torch.load(buffer)
+
+
+def transitions_to_bytes(transitions: list[Transition]) -> bytes:
+    buffer = io.BytesIO()
+    torch.save(transitions, buffer)
+    return buffer.getvalue()
 
 
 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
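This hunk drops the old bytes_buffer_size utility and adds explicit bytes <-> object conversions, so that model weights, arbitrary Python objects, and replay transitions can cross a process boundary as plain, picklable byte payloads once the actor and learner run under multiprocessing instead of threads. A rough usage sketch of the state-dict round trip follows; the two helpers are re-declared inline so the snippet stands alone, and the Pipe-based hand-off is only an illustration, not how this commit actually wires its processes:

    import io
    import multiprocessing as mp

    import torch

    def state_to_bytes(state_dict):
        # Same round trip as the helper added in this commit: serialize into an
        # in-memory buffer and hand back the raw bytes.
        buffer = io.BytesIO()
        torch.save(state_dict, buffer)
        return buffer.getvalue()

    def bytes_to_state_dict(payload):
        # Rebuild the tensors on the receiving side of the process boundary.
        return torch.load(io.BytesIO(payload))

    def worker(conn):
        state_dict = bytes_to_state_dict(conn.recv())
        print({name: tuple(tensor.shape) for name, tensor in state_dict.items()})

    if __name__ == "__main__":
        parent_conn, child_conn = mp.Pipe()
        process = mp.Process(target=worker, args=(child_conn,))
        process.start()
        parent_conn.send(state_to_bytes({"weight": torch.randn(4, 4)}))
        process.join()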