[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

Author: Eugene Mironov
Date: 2025-02-21 16:29:00 +07:00
Committed by: AdilZouitine
Parent: 150def839c
Commit: d48161da1b
17 changed files with 1949 additions and 475 deletions


@@ -17,6 +17,7 @@ import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
@@ -41,24 +42,33 @@ class BatchTransition(TypedDict):
done: torch.Tensor
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
def move_transition_to_device(
transition: Transition, device: str = "cpu"
) -> Transition:
# Move state tensors to the target device
device = torch.device(device)
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["state"].items()
}
# Move action to the target device
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
transition["action"] = transition["action"].to(
device, non_blocking=device.type == "cuda"
)
# No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
transition["reward"] = transition["reward"].to(
device=device, non_blocking=device.type == "cuda"
)
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
transition["done"] = transition["done"].to(
device, non_blocking=device.type == "cuda"
)
# Move next_state tensors to the target device
transition["next_state"] = {
@@ -82,7 +92,10 @@ def move_state_dict_to_device(state_dict, device):
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
return {
k: move_state_dict_to_device(v, device=device)
for k, v in state_dict.items()
}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
@@ -91,6 +104,22 @@ def move_state_dict_to_device(state_dict, device):
return state_dict
def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
"""Convert model state dict to flat array for transmission"""
buffer = io.BytesIO()
torch.save(state_dict, buffer)
return buffer
def bytes_buffer_size(buffer: io.BytesIO) -> int:
buffer.seek(0, io.SEEK_END)
result = buffer.tell()
buffer.seek(0)
return result
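A minimal sketch of how the two new helpers could be combined to ship learner parameters to the actor. The nn.Linear policy and the omitted transport layer are illustrative stand-ins, not part of this commit.

import io
import torch
from torch import nn

policy = nn.Linear(8, 4)  # stand-in for the actual policy network

# Learner side: serialize the current parameters and report the payload size.
buffer = state_to_bytes(policy.state_dict())
payload = buffer.getvalue()
print(f"Sending {bytes_buffer_size(buffer)} bytes to the actor")

# Actor side: restore the parameters from the received bytes.
policy.load_state_dict(torch.load(io.BytesIO(payload)))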
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
"""
Perform a per-image random crop over a batch of images in a vectorized way.
@@ -116,7 +145,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
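As an illustration, a hedged sketch of the vectorized crop used as a DrQ-style augmentation over a batch of image observations; the batch size and resolutions are made up.

import torch

# Batch of 32 RGB observations, 128x128 pixels, in (B, C, H, W) layout.
images = torch.rand(32, 3, 128, 128)

# Every image in the batch gets its own random crop offsets in a single call.
cropped = random_crop_vectorized(images, output_size=(112, 112))
print(cropped.shape)  # torch.Size([32, 3, 112, 112])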
@@ -179,7 +210,9 @@ class ReplayBuffer:
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Move tensors to the storage device
state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
next_state = {
key: tensor.to(self.storage_device) for key, tensor in next_state.items()
}
action = action.to(self.storage_device)
# if complementary_info is not None:
# complementary_info = {
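For reference, a sketch of how an actor loop might push transitions into the buffer. The constructor keywords match the cls(...) call visible further down; the method name add and its keyword arguments are assumptions, since this hunk only shows the method body.

import torch

buffer = ReplayBuffer(capacity=10_000, device="cpu", state_keys=["observation.state"])

state = {"observation.state": torch.randn(1, 7)}
next_state = {"observation.state": torch.randn(1, 7)}
action = torch.randn(1, 4)

# Assumed signature: the diff above only shows the body that moves tensors
# to the storage device before the transition is appended.
buffer.add(state=state, action=action, reward=0.0, next_state=next_state, done=False)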
@@ -234,7 +267,9 @@ class ReplayBuffer:
)
replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
for k, v in data.items():
@@ -295,7 +330,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@@ -350,33 +387,37 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
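A hedged sketch of the learner-side use of this batching code. The sample(batch_size) name is an assumption (only the method body appears here), and the buffer is assumed to have been filled by the actor, e.g. as in the earlier sketch.

# buffer is a filled ReplayBuffer; every sampled field lands on self.device.
batch = buffer.sample(256)  # assumed method name

observations = batch["state"]            # dict of tensors batched along dim 0
actions = batch["action"]                # (256, action_dim)
rewards = batch["reward"]                # (256,) float32
next_observations = batch["next_state"]  # dict, same keys as state
dones = batch["done"]                    # (256,) float32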
@@ -433,7 +474,9 @@ class ReplayBuffer:
# Add state keys
for key in self.state_keys:
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
sample_val = first_transition["state"][key].squeeze(
dim=0
) # Remove batch dimension
if not isinstance(sample_val, torch.Tensor):
raise ValueError(
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
@@ -465,7 +508,9 @@ class ReplayBuffer:
# We detect episode boundaries by `done == True`.
# --------------------------------------------------------------------------------------------
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
frame_idx_in_episode = 0
for global_frame_idx, transition in enumerate(self.memory):
@@ -476,16 +521,24 @@ class ReplayBuffer:
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
# This is typically already correct, but if needed you can reshape below.
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict[key] = (
transition["state"][key].cpu().squeeze(dim=0)
) # Remove batch dimension
# Fill action, reward, done
# Make sure they are shape (X,) or (X,Y,...) as needed.
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
frame_dict["action"] = (
transition["action"].cpu().squeeze(dim=0)
) # Remove batch dimension
frame_dict["next.reward"] = (
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
torch.tensor([transition["reward"]], dtype=torch.float32)
.cpu()
.squeeze(dim=0)
)
frame_dict["next.done"] = (
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
torch.tensor([transition["done"]], dtype=torch.bool)
.cpu()
.squeeze(dim=0)
)
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
@@ -499,7 +552,9 @@ class ReplayBuffer:
episode_index += 1
frame_idx_in_episode = 0
# Start a new buffer for the next episode
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index
)
# We are done adding frames
# If the last transition wasn't done=True, we still have an open buffer with frames.
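For context, a heavily hedged sketch of exporting the buffer after a recording session. The method name to_lerobot_dataset and its arguments are assumptions; these hunks only show the episode-splitting body that writes frames into the dataset buffer.

# Episodes are delimited by transitions whose done flag is True, as implemented above;
# each buffered frame becomes one dataset frame with next.reward / next.done entries.
dataset = buffer.to_lerobot_dataset(repo_id="user/hil_serl_rollouts")  # assumed API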
@@ -541,7 +596,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@@ -552,7 +613,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
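Finally, a sketch of the typical use of this helper in an RLPD/HIL-SERL-style update: merging an online batch with an offline (demonstration) batch before a single gradient step. The two buffers, the 50/50 split, and the sample name are illustrative assumptions.

online_batch = online_buffer.sample(128)     # assumed sample() API, as above
offline_batch = offline_buffer.sample(128)

# Careful: this mutates online_batch in place and returns it.
mixed_batch = concatenate_batch_transitions(online_batch, offline_batch)
# mixed_batch now holds 256 transitions for the critic/actor update.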