forked from tangger/lerobot

[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

commit d48161da1b
parent 150def839c
committed by AdilZouitine
@@ -17,6 +17,7 @@ import functools
 import random
 from typing import Any, Callable, Optional, Sequence, TypedDict
 
+import io
 import torch
 import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm
@@ -41,24 +42,33 @@ class BatchTransition(TypedDict):
     done: torch.Tensor
 
 
-def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
+def move_transition_to_device(
+    transition: Transition, device: str = "cpu"
+) -> Transition:
     # Move state tensors to CPU
     device = torch.device(device)
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["state"].items()
     }
 
     # Move action to CPU
-    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
+    transition["action"] = transition["action"].to(
+        device, non_blocking=device.type == "cuda"
+    )
 
     # No need to move reward or done, as they are float and bool
 
     # No need to move reward or done, as they are float and bool
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
+        transition["reward"] = transition["reward"].to(
+            device=device, non_blocking=device.type == "cuda"
+        )
 
     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
+        transition["done"] = transition["done"].to(
+            device, non_blocking=device.type == "cuda"
+        )
 
     # Move next_state tensors to CPU
     transition["next_state"] = {
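For context, a minimal usage sketch of the rewrapped helper (the transition contents below are invented for illustration, e.g. an actor handing a collected transition to a CPU-side queue for the learner):

import torch

# Invented example transition; keys follow the Transition TypedDict above.
transition = {
    "state": {"observation.state": torch.randn(1, 10)},
    "action": torch.randn(1, 4),
    "reward": 1.0,
    "done": False,
    "next_state": {"observation.state": torch.randn(1, 10)},
}
# non_blocking only takes effect when the target device is CUDA.
transition = move_transition_to_device(transition, device="cpu")
assert transition["action"].device.type == "cpu"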
@@ -82,7 +92,10 @@ def move_state_dict_to_device(state_dict, device):
     if isinstance(state_dict, torch.Tensor):
         return state_dict.to(device)
     elif isinstance(state_dict, dict):
-        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
+        return {
+            k: move_state_dict_to_device(v, device=device)
+            for k, v in state_dict.items()
+        }
     elif isinstance(state_dict, list):
         return [move_state_dict_to_device(v, device=device) for v in state_dict]
     elif isinstance(state_dict, tuple):
@@ -91,6 +104,22 @@ def move_state_dict_to_device(state_dict, device):
     return state_dict
 
 
+def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> io.BytesIO:
+    """Convert model state dict to flat array for transmission"""
+    buffer = io.BytesIO()
+
+    torch.save(state_dict, buffer)
+
+    return buffer
+
+
+def bytes_buffer_size(buffer: io.BytesIO) -> int:
+    buffer.seek(0, io.SEEK_END)
+    result = buffer.tell()
+    buffer.seek(0)
+    return result
+
+
 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
     """
     Perform a per-image random crop over a batch of images in a vectorized way.
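A quick round-trip sketch for the two new helpers (the layer name and weights are made up): torch.save writes the state dict into an in-memory buffer, bytes_buffer_size measures the payload and rewinds the buffer, and torch.load restores it on the receiving side.

import torch

state_dict = {"layer.weight": torch.randn(4, 4)}  # made-up weights
buffer = state_to_bytes(state_dict)
print(f"payload size: {bytes_buffer_size(buffer)} bytes")  # also rewinds to offset 0

restored = torch.load(buffer)  # works because the buffer was rewound to 0
assert torch.equal(restored["layer.weight"], state_dict["layer.weight"])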
@@ -116,7 +145,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
     images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
 
     # Gather pixels
-    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
+    cropped_hwcn = images_hwcn[
+        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
+    ]
     # cropped_hwcn => (B, crop_h, crop_w, C)
 
     cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
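The hunk only shows the gather step; for readers following along, this is one way the full vectorized crop can be assembled (a self-contained sketch, not the file's exact code; the rows/cols construction and the top/left names are assumptions):

import torch

def random_crop_sketch(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    B, C, H, W = images.shape
    crop_h, crop_w = output_size
    # One random top-left corner per image in the batch.
    top = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    left = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
    # rows: (B, crop_h, 1) and cols: (B, 1, crop_w) broadcast to (B, crop_h, crop_w).
    rows = top.view(B, 1, 1) + torch.arange(crop_h, device=images.device).view(1, crop_h, 1)
    cols = left.view(B, 1, 1) + torch.arange(crop_w, device=images.device).view(1, 1, crop_w)
    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
    # Advanced indexing gathers a (B, crop_h, crop_w, C) block in one shot.
    cropped_hwcn = images_hwcn[
        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
    ]
    return cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)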
@@ -179,7 +210,9 @@ class ReplayBuffer:
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Move tensors to the storage device
         state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
-        next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
+        next_state = {
+            key: tensor.to(self.storage_device) for key, tensor in next_state.items()
+        }
         action = action.to(self.storage_device)
         # if complementary_info is not None:
         #     complementary_info = {
@@ -234,7 +267,9 @@ class ReplayBuffer:
         )
 
         replay_buffer = cls(capacity=capacity, device=device, state_keys=state_keys)
-        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        list_transition = cls._lerobotdataset_to_transitions(
+            dataset=lerobot_dataset, state_keys=state_keys
+        )
         # Fill the replay buffer with the lerobot dataset transitions
         for data in list_transition:
             for k, v in data.items():
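A usage sketch for the classmethod this hunk touches (the method name from_lerobot_dataset and the argument values are assumptions inferred from the surrounding code, not confirmed by the diff): seed the replay buffer with offline demonstrations before online interaction starts.

# Hypothetical call; `dataset` is a previously loaded LeRobotDataset,
# and capacity/device/state_keys depend on your setup.
offline_buffer = ReplayBuffer.from_lerobot_dataset(
    lerobot_dataset=dataset,
    capacity=100_000,
    device="cuda:0",
    state_keys=["observation.state", "observation.image"],
)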
@@ -295,7 +330,9 @@ class ReplayBuffer:
 
         # If not provided, you can either raise an error or define a default:
         if state_keys is None:
-            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
+            raise ValueError(
+                "You must provide a list of keys in `state_keys` that define your 'state'."
+            )
 
         transitions: list[Transition] = []
         num_frames = len(dataset)
@@ -350,33 +387,37 @@ class ReplayBuffer:
         # -- Build batched states --
         batch_state = {}
         for key in self.state_keys:
-            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_state[key] = torch.cat(
+                [t["state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
             if key.startswith("observation.image") and self.use_drq:
                 batch_state[key] = self.image_augmentation_function(batch_state[key])
 
         # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
-
-        # -- Build batched rewards --
-        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
+        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
+            self.device
+        )
+
+        # -- Build batched rewards --
+        batch_rewards = torch.tensor(
+            [t["reward"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
 
         # -- Build batched next states --
         batch_next_state = {}
         for key in self.state_keys:
-            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_next_state[key] = torch.cat(
+                [t["next_state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
             if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
+                batch_next_state[key] = self.image_augmentation_function(
+                    batch_next_state[key]
+                )
 
         # -- Build batched dones --
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
+        batch_dones = torch.tensor(
+            [t["done"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
 
         # Return a BatchTransition typed dict
         return BatchTransition(
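The image_augmentation_function applied when use_drq is set is not shown in this diff; a common DrQ-style choice (an assumption here, sketched with the random_crop_vectorized helper defined earlier in this file) pads the batch and random-crops it back to its original size, i.e. a small random shift per image:

import torch
import torch.nn.functional as F  # noqa: N812

def random_shift_sketch(images: torch.Tensor, pad: int = 4) -> torch.Tensor:
    _, _, h, w = images.shape
    # Replicate-pad, then crop back to (h, w) at a random per-image offset.
    padded = F.pad(images, (pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(padded, output_size=(h, w))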
@@ -433,7 +474,9 @@ class ReplayBuffer:
 
         # Add state keys
         for key in self.state_keys:
-            sample_val = first_transition["state"][key].squeeze(dim=0)  # Remove batch dimension
+            sample_val = first_transition["state"][key].squeeze(
+                dim=0
+            )  # Remove batch dimension
             if not isinstance(sample_val, torch.Tensor):
                 raise ValueError(
                     f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
@@ -465,7 +508,9 @@ class ReplayBuffer:
         # We detect episode boundaries by `done == True`.
         # --------------------------------------------------------------------------------------------
         episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+            episode_index
+        )
 
         frame_idx_in_episode = 0
         for global_frame_idx, transition in enumerate(self.memory):
@@ -476,16 +521,24 @@ class ReplayBuffer:
                 # Expand dimension to match what the dataset expects (the dataset wants the raw shape)
                 # We assume your buffer has shape [C, H, W] (if image) or [D] if vector
                 # This is typically already correct, but if needed you can reshape below.
-                frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0)  # Remove batch dimension
+                frame_dict[key] = (
+                    transition["state"][key].cpu().squeeze(dim=0)
+                )  # Remove batch dimension
 
             # Fill action, reward, done
             # Make sure they are shape (X,) or (X,Y,...) as needed.
-            frame_dict["action"] = transition["action"].cpu().squeeze(dim=0)  # Remove batch dimension
+            frame_dict["action"] = (
+                transition["action"].cpu().squeeze(dim=0)
+            )  # Remove batch dimension
             frame_dict["next.reward"] = (
-                torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
+                torch.tensor([transition["reward"]], dtype=torch.float32)
+                .cpu()
+                .squeeze(dim=0)
             )
             frame_dict["next.done"] = (
-                torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
+                torch.tensor([transition["done"]], dtype=torch.bool)
+                .cpu()
+                .squeeze(dim=0)
             )
             # Add to the dataset's buffer
             lerobot_dataset.add_frame(frame_dict)
@@ -499,7 +552,9 @@ class ReplayBuffer:
                 episode_index += 1
                 frame_idx_in_episode = 0
                 # Start a new buffer for the next episode
-                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
+                lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
+                    episode_index
+                )
 
         # We are done adding frames
         # If the last transition wasn't done=True, we still have an open buffer with frames.
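Outside the diff, the episode-splitting logic above boils down to this pattern (a standalone sketch with assumed names, not the file's code): every done=True closes the current episode, and any trailing frames after the last done form a final, still-open episode.

episodes, current = [], []
for transition in memory:  # plays the role of self.memory in the buffer
    current.append(transition)
    if transition["done"]:
        episodes.append(current)
        current = []
if current:  # the last transition wasn't done=True
    episodes.append(current)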
@@ -541,7 +596,13 @@ def concatenate_batch_transitions(
 ) -> BatchTransition:
     """NOTE: Be careful it change the left_batch_transitions in place"""
     left_batch_transitions["state"] = {
-        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
+        key: torch.cat(
+            [
+                left_batch_transitions["state"][key],
+                right_batch_transition["state"][key],
+            ],
+            dim=0,
+        )
         for key in left_batch_transitions["state"]
     }
     left_batch_transitions["action"] = torch.cat(
@@ -552,7 +613,11 @@ def concatenate_batch_transitions(
     )
     left_batch_transitions["next_state"] = {
         key: torch.cat(
-            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
+            [
+                left_batch_transitions["next_state"][key],
+                right_batch_transition["next_state"][key],
+            ],
+            dim=0,
         )
         for key in left_batch_transitions["next_state"]
     }
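A usage sketch for this in-place concatenation (all names other than concatenate_batch_transitions are assumptions): SERL-style training draws half of each batch from the online buffer and half from offline demonstrations, then fuses them into one BatchTransition.

half = batch_size // 2  # batch_size assumed from the training config
online_batch = online_buffer.sample(half)
demo_batch = offline_buffer.sample(half)
# Per the NOTE in the docstring, the left batch is mutated in place and
# afterwards holds all 2 * half transitions.
concatenate_batch_transitions(online_batch, demo_batch)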