forked from tangger/lerobot

commit 0ea27704f6 (parent 2abbd60a0d), committed by Michel Aractingi

    [pre-commit.ci] auto fixes from pre-commit.com hooks
    for more information, see https://pre-commit.ci

The hunks below are mechanical re-formatting of lerobot's replay-buffer module: statements that were wrapped over several lines are joined into single lines that fit within the formatter's maximum line length, and the import block is regrouped. No behavior changes.

@@ -14,16 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import io
+import os
+import pickle
 from typing import Any, Callable, Optional, Sequence, TypedDict
 
-import io
 import torch
 import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-import os
-import pickle
 
 
 class Transition(TypedDict):
@@ -45,38 +45,27 @@ class BatchTransition(TypedDict):
     truncated: torch.Tensor
 
 
-def move_transition_to_device(
-    transition: Transition, device: str = "cpu"
-) -> Transition:
+def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
     # Move state tensors to CPU
     device = torch.device(device)
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
     }
 
     # Move action to CPU
-    transition["action"] = transition["action"].to(
-        device, non_blocking=device.type == "cuda"
-    )
+    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
 
     # No need to move reward or done, as they are float and bool
 
     # No need to move reward or done, as they are float and bool
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(
-            device=device, non_blocking=device.type == "cuda"
-        )
+        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
 
     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(
-            device, non_blocking=device.type == "cuda"
-        )
+        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
 
     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(
-            device, non_blocking=device.type == "cuda"
-        )
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
 
     # Move next_state tensors to CPU
     transition["next_state"] = {
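
A note on the `non_blocking=device.type == "cuda"` pattern the joined lines preserve: a standalone sketch, not lerobot code, with illustrative names. `non_blocking=True` lets host-to-GPU copies overlap with compute, but it only helps when the source lives in pinned host memory, and it is ignored for CPU destinations, which is why the flag is gated on the target device type:

```python
import torch

def to_device(batch: dict, device: str = "cpu") -> dict:
    dev = torch.device(device)
    # non_blocking only takes effect for copies into CUDA from pinned host
    # memory; for CPU destinations it is ignored, hence the device-type guard.
    return {k: v.to(dev, non_blocking=dev.type == "cuda") for k, v in batch.items()}

cpu_batch = {"observation.state": torch.randn(1, 6)}
moved = to_device(cpu_batch, "cpu")  # plain synchronous copy on CPU
if torch.cuda.is_available():
    # Pinning the source tensors lets the async host-to-device copy overlap compute.
    pinned = {k: v.pin_memory() for k, v in cpu_batch.items()}
    moved = to_device(pinned, "cuda")
```
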
@@ -100,10 +89,7 @@ def move_state_dict_to_device(state_dict, device="cpu"):
     if isinstance(state_dict, torch.Tensor):
         return state_dict.to(device)
     elif isinstance(state_dict, dict):
-        return {
-            k: move_state_dict_to_device(v, device=device)
-            for k, v in state_dict.items()
-        }
+        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
     elif isinstance(state_dict, list):
         return [move_state_dict_to_device(v, device=device) for v in state_dict]
     elif isinstance(state_dict, tuple):
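
For reference, a from-scratch sketch of the recursive pattern this hunk touches (not the lerobot implementation; the helper name and sample data are made up): the container structure is preserved while every tensor leaf is relocated.

```python
import torch

def move_nested(obj, device="cpu"):
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_nested(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild with the same container type so lists stay lists, tuples tuples.
        return type(obj)(move_nested(v, device) for v in obj)
    return obj  # non-tensor leaves (ints, strings, None) pass through unchanged

nested = {"actor": {"w": torch.randn(4, 4)}, "stats": [torch.tensor(1.0)]}
moved = move_nested(nested, "cpu")
assert moved["actor"]["w"].device.type == "cpu"
```
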
@@ -174,9 +160,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
     images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
 
     # Gather pixels
-    cropped_hwcn = images_hwcn[
-        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
-    ]
+    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
     # cropped_hwcn => (B, crop_h, crop_w, C)
 
     cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
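
The gather line above is the heart of a per-image random crop with no Python loop over the batch. Below is a self-contained sketch of the technique; the construction of `rows` and `cols` is not shown in this hunk, so their shapes here are an assumption derived from what the broadcasted advanced indexing requires:

```python
import torch

def random_crop_batch(images: torch.Tensor, crop_h: int, crop_w: int) -> torch.Tensor:
    B, C, H, W = images.shape
    top = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    left = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
    # rows: (B, crop_h, 1) and cols: (B, 1, crop_w) broadcast to (B, crop_h, crop_w)
    rows = top.view(B, 1, 1) + torch.arange(crop_h, device=images.device).view(1, crop_h, 1)
    cols = left.view(B, 1, 1) + torch.arange(crop_w, device=images.device).view(1, 1, crop_w)
    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
    # Advanced indexing broadcasts the three index tensors; ":" keeps channels.
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    return cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)

out = random_crop_batch(torch.randn(8, 3, 64, 64), 56, 56)
assert out.shape == (8, 3, 56, 56)
```
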
@@ -223,9 +207,7 @@ class ReplayBuffer:
         self.optimize_memory = optimize_memory
 
         # Track episode boundaries for memory optimization
-        self.episode_ends = torch.zeros(
-            capacity, dtype=torch.bool, device=storage_device
-        )
+        self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
 
         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -246,9 +228,7 @@ class ReplayBuffer:
             key: torch.empty((self.capacity, *shape), device=self.storage_device)
             for key, shape in state_shapes.items()
         }
-        self.actions = torch.empty(
-            (self.capacity, *action_shape), device=self.storage_device
-        )
+        self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)
 
         if not self.optimize_memory:
@@ -262,12 +242,8 @@ class ReplayBuffer:
             # Just create a reference to states for consistent API
             self.next_states = self.states  # Just a reference for API consistency
 
-        self.dones = torch.empty(
-            (self.capacity,), dtype=torch.bool, device=self.storage_device
-        )
-        self.truncateds = torch.empty(
-            (self.capacity,), dtype=torch.bool, device=self.storage_device
-        )
+        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.initialized = True
 
     def __len__(self):
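
The constructor hunks above all follow one pattern: every field is allocated once with `torch.empty` at full capacity, and writes later wrap around a cursor. A minimal sketch of that ring-buffer idea, with made-up dimensions (not the lerobot class):

```python
import torch

class TinyRingBuffer:
    def __init__(self, capacity: int, obs_dim: int, act_dim: int):
        # One upfront allocation per field; nothing grows at runtime.
        self.capacity = capacity
        self.states = torch.empty((capacity, obs_dim))
        self.actions = torch.empty((capacity, act_dim))
        self.rewards = torch.empty((capacity,))
        self.dones = torch.empty((capacity,), dtype=torch.bool)
        self.position, self.size = 0, 0

    def add(self, s, a, r, d):
        self.states[self.position].copy_(s)
        self.actions[self.position].copy_(a)
        self.rewards[self.position] = r
        self.dones[self.position] = d
        self.position = (self.position + 1) % self.capacity  # wrap around
        self.size = min(self.size + 1, self.capacity)

buf = TinyRingBuffer(1000, obs_dim=6, act_dim=3)
buf.add(torch.randn(6), torch.randn(3), 1.0, False)
```
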
@@ -294,9 +270,7 @@ class ReplayBuffer:
 
             if not self.optimize_memory:
                 # Only store next_states if not optimizing memory
-                self.next_states[key][self.position].copy_(
-                    next_state[key].squeeze(dim=0)
-                )
+                self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
 
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -309,23 +283,15 @@ class ReplayBuffer:
     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
         if not self.initialized:
-            raise RuntimeError(
-                "Cannot sample from an empty buffer. Add transitions first."
-            )
+            raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
 
         batch_size = min(batch_size, self.size)
 
         # Random indices for sampling - create on the same device as storage
-        idx = torch.randint(
-            low=0, high=self.size, size=(batch_size,), device=self.storage_device
-        )
+        idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)
 
         # Identify image keys that need augmentation
-        image_keys = (
-            [k for k in self.states if k.startswith("observation.image")]
-            if self.use_drq
-            else []
-        )
+        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
 
         # Create batched state and next_state
         batch_state = {}
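
Why the `idx = torch.randint(..., device=self.storage_device)` line creates indices on the storage device rather than the compute device: in PyTorch, indexing a tensor with an index tensor on a different device raises a device-mismatch error. A tiny standalone illustration (not lerobot code):

```python
import torch

storage = torch.arange(10)        # replay storage lives on the CPU here
idx = torch.randint(0, 10, (4,))  # indices on the same (storage) device
batch = storage[idx]              # gather happens on the storage device
# Indexing the CPU `storage` with CUDA indices would raise an error, so only
# the gathered batch is moved to the compute device afterwards:
# batch = batch.to("cuda")
```
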
@@ -358,13 +324,9 @@ class ReplayBuffer:
             # Split the augmented images back to their sources
             for i, key in enumerate(image_keys):
                 # State images are at even indices (0, 2, 4...)
-                batch_state[key] = augmented_images[
-                    i * 2 * batch_size : (i * 2 + 1) * batch_size
-                ]
+                batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
                 # Next state images are at odd indices (1, 3, 5...)
-                batch_next_state[key] = augmented_images[
-                    (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
-                ]
+                batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
 
         # Sample other tensors
         batch_actions = self.actions[idx].to(self.device)
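
A sketch of the slice bookkeeping in this hunk, with made-up batch size and key names. The idea: for each image key, the state and next-state batches are stacked back-to-back as [state_k0, next_k0, state_k1, next_k1, ...] so a single augmentation pass covers everything; the even/odd block slices then split them apart again.

```python
import torch

batch_size, keys = 4, ["observation.image", "observation.image.wrist"]
state = {k: torch.full((batch_size, 1), fill_value=2.0 * i) for i, k in enumerate(keys)}
next_state = {k: torch.full((batch_size, 1), fill_value=2.0 * i + 1) for i, k in enumerate(keys)}

# Interleave: state block then next_state block, per key, along the batch dim.
stacked = torch.cat([t for k in keys for t in (state[k], next_state[k])], dim=0)
augmented = stacked  # stand-in for the real image augmentation call

for i, k in enumerate(keys):
    out_state = augmented[i * 2 * batch_size : (i * 2 + 1) * batch_size]
    out_next = augmented[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
    assert torch.equal(out_state, state[k]) and torch.equal(out_next, next_state[k])
```
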
@@ -434,16 +396,12 @@ class ReplayBuffer:
         )
 
         # Convert dataset to transitions
-        list_transition = cls._lerobotdataset_to_transitions(
-            dataset=lerobot_dataset, state_keys=state_keys
-        )
+        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
 
         # Initialize the buffer with the first transition to set up storage tensors
         if list_transition:
             first_transition = list_transition[0]
-            first_state = {
-                k: v.to(device) for k, v in first_transition["state"].items()
-            }
+            first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
             first_action = first_transition["action"].to(device)
 
             # Apply action mask/delta if needed
@@ -541,9 +499,7 @@ class ReplayBuffer:
 
         # Convert transitions into episodes and frames
         episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-            episode_index=episode_index
-        )
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
 
         frame_idx_in_episode = 0
         for idx in range(self.size):
@@ -557,12 +513,8 @@ class ReplayBuffer:
 
             # Fill action, reward, done
             frame_dict["action"] = self.actions[actual_idx].cpu()
-            frame_dict["next.reward"] = torch.tensor(
-                [self.rewards[actual_idx]], dtype=torch.float32
-            ).cpu()
-            frame_dict["next.done"] = torch.tensor(
-                [self.dones[actual_idx]], dtype=torch.bool
-            ).cpu()
+            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
+            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
 
             # Add to the dataset's buffer
             lerobot_dataset.add_frame(frame_dict)
@@ -619,9 +571,7 @@ class ReplayBuffer:
             A list of Transition dictionaries with the same length as `dataset`.
         """
         if state_keys is None:
-            raise ValueError(
-                "State keys must be provided when converting LeRobotDataset to Transitions."
-            )
+            raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
 
         transitions = []
         num_frames = len(dataset)
@@ -632,9 +582,7 @@ class ReplayBuffer:
 
         # If not, we need to infer it from episode boundaries
        if not has_done_key:
-            print(
-                "'next.done' key not found in dataset. Inferring from episode boundaries..."
-            )
+            print("'next.done' key not found in dataset. Inferring from episode boundaries...")
 
         for i in tqdm(range(num_frames)):
             current_sample = dataset[i]
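
A hedged sketch of the "infer done from episode boundaries" fallback mentioned in this hunk (field name `episode_index` assumed from the LeRobotDataset layout; the actual per-frame loop may differ): a frame is terminal when the next frame belongs to a different episode, or when it is the last frame overall.

```python
import torch

episode_index = torch.tensor([0, 0, 0, 1, 1, 2])  # one entry per frame
num_frames = episode_index.numel()
done = torch.zeros(num_frames, dtype=torch.bool)
done[:-1] = episode_index[1:] != episode_index[:-1]  # episode boundary => terminal
done[-1] = True                                      # last frame ends its episode
assert done.tolist() == [False, False, True, False, True, True]
```
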
@@ -886,8 +834,7 @@ if __name__ == "__main__":
         # We need to be careful because we don't know the original index
         # So we check if the increment is roughly 0.01
         next_state_check = (
-            abs(next_state_sig - state_sig - 0.01) < 1e-4
-            or abs(next_state_sig - state_sig) < 1e-4
+            abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
         )
 
         # Count correct relationships
@@ -901,17 +848,11 @@ if __name__ == "__main__":
         total_checks += 3
 
     alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(
-        f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
-    )
+    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
     if alignment_accuracy > 99.0:
-        print(
-            "✅ All relationships verified! Buffer maintains correct temporal relationships."
-        )
+        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
     else:
-        print(
-            "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
-        )
+        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
 
     # Print some debug information about failures
     print("\nDebug information for failed checks:")
@@ -973,18 +914,14 @@ if __name__ == "__main__":
 
     # Verify consistency before and after conversion
     original_states = batch["state"]["observation.image"].mean().item()
-    reconverted_states = (
-        reconverted_batch["state"]["observation.image"].mean().item()
-    )
+    reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
     print(f"Original buffer state mean: {original_states:.4f}")
     print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
 
     if abs(original_states - reconverted_states) < 1.0:
         print("Values are reasonably similar - conversion works as expected")
     else:
-        print(
-            "WARNING: Significant difference between original and reconverted values"
-        )
+        print("WARNING: Significant difference between original and reconverted values")
 
     print("\nAll previous tests completed!")
 
@@ -1093,15 +1030,11 @@ if __name__ == "__main__":
     all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
 
     # Get state tensors
-    batch_state = {
-        "value": test_buffer.states["value"][all_indices].to(test_buffer.device)
-    }
+    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
 
     # Get next_state using memory-optimized approach (simply index+1)
     next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {
-        "value": test_buffer.states["value"][next_indices].to(test_buffer.device)
-    }
+    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
 
     # Get other tensors
     batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
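
A sketch of the memory optimization being exercised here (standalone toy data, not the test's actual buffer): instead of storing a separate `next_states` copy, the next state is recovered as the entry at index+1 modulo capacity. At episode boundaries that "next state" is really the first state of the following episode, which is exactly why the done flag must be stored alongside it.

```python
import torch

capacity = 8
states = torch.arange(float(capacity))                    # stand-in state storage
dones = torch.tensor([0, 0, 1, 0, 0, 0, 1, 0], dtype=torch.bool)

idx = torch.tensor([1, 2, 5])
next_idx = (idx + 1) % capacity                           # no next_states tensor needed
batch_state, batch_next = states[idx], states[next_idx]
# In a TD target the done flag masks the bootstrapped value, so the bogus
# cross-episode "next state" at idx=2 is never used:
# target = r + gamma * (~dones[idx]).float() * v_next
```
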
@@ -1121,9 +1054,7 @@ if __name__ == "__main__":
     print("- We always use the next state in the buffer (index+1) as next_state")
     print("- For terminal states, this means using the first state of the next episode")
     print("- This is a common tradeoff in RL implementations for memory efficiency")
-    print(
-        "- Since we track done flags, the algorithm can handle these transitions correctly"
-    )
+    print("- Since we track done flags, the algorithm can handle these transitions correctly")
 
     # Test random sampling
     print("\nVerifying random sampling with simplified memory optimization...")
@@ -1137,23 +1068,19 @@ if __name__ == "__main__":
     # Print a few samples
     print("Random samples - State, Next State, Done (First 10):")
     for i in range(10):
-        print(
-            f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
-        )
+        print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
 
     # Calculate memory savings
     # Assume optimized_buffer and standard_buffer have already been initialized and filled
     std_mem = (
         sum(
-            standard_buffer.states[key].nelement()
-            * standard_buffer.states[key].element_size()
+            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
             for key in standard_buffer.states
         )
         * 2
     )
     opt_mem = sum(
-        optimized_buffer.states[key].nelement()
-        * optimized_buffer.states[key].element_size()
+        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
         for key in optimized_buffer.states
     )
 
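
The accounting in the last hunk in miniature (shapes made up): `nelement() * element_size()` gives a tensor's payload in bytes, and the standard buffer is charged twice for state storage because it also keeps a `next_states` copy, while the optimized buffer stores states only once.

```python
import torch

states = {"observation.image": torch.empty(1000, 3, 84, 84)}
per_copy = sum(t.nelement() * t.element_size() for t in states.values())
std_mem = per_copy * 2  # states + next_states
opt_mem = per_copy      # states only; next_state comes from index + 1
print(f"savings: {(std_mem - opt_mem) / std_mem:.0%}")  # 50% of state storage
```
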