[HIL-SERL PORT] Unit tests for Replay Buffer (#966)

This commit is contained in:
Eugene Mironov
2025-04-22 14:35:57 +07:00
committed by GitHub
parent dc726cb9a3
commit 0030ff3f74
2 changed files with 614 additions and 104 deletions

View File

@@ -15,7 +15,8 @@
# limitations under the License.
import functools
import io
import pickle
import pickle # nosec B403: Safe usage of pickle
from contextlib import suppress
from typing import Any, Callable, Optional, Sequence, TypedDict
import torch
@@ -113,7 +114,7 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    """Deserialize a state dict from raw bytes using ``torch.load``.

    Args:
        buffer: Serialized payload in a format ``torch.load`` can read.

    Returns:
        Whatever ``torch.load`` reconstructs from ``buffer``; annotated as a
        mapping from parameter name to tensor.

    Note:
        ``torch.load`` may unpickle its input, so only bytes coming from a
        trusted source should be passed in.
    """
    # A freshly built BytesIO is already positioned at offset 0, so no explicit
    # seek is required. A separate name keeps the `bytes` parameter intact.
    stream = io.BytesIO(buffer)
    return torch.load(stream)  # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
@@ -123,13 +124,17 @@ def python_object_to_bytes(python_object: Any) -> bytes:
def bytes_to_python_object(buffer: bytes) -> Any:
    """Deserialize an arbitrary Python object from pickled bytes.

    Args:
        buffer: A pickle payload.

    Returns:
        The object reconstructed by ``pickle.load``.

    Note:
        Unpickling can execute arbitrary code; only pass bytes that originate
        from a trusted source.
    """
    # A freshly built BytesIO starts at offset 0; no seek is needed.
    stream = io.BytesIO(buffer)
    obj = pickle.load(stream)  # nosec B301: Safe usage of pickle.load
    # Add validation checks here
    return obj
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    """Deserialize a list of transitions from raw bytes using ``torch.load``.

    Args:
        buffer: Serialized payload in a format ``torch.load`` can read.

    Returns:
        Whatever ``torch.load`` reconstructs from ``buffer``; annotated as a
        list of ``Transition`` items. No structural validation is performed.

    Note:
        ``torch.load`` may unpickle its input, so only bytes coming from a
        trusted source should be passed in.
    """
    # A freshly built BytesIO starts at offset 0; no seek is needed.
    stream = io.BytesIO(buffer)
    transitions = torch.load(stream)  # nosec B614: Safe usage of torch.load
    # Add validation checks here
    return transitions
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
@@ -201,6 +206,9 @@ class ReplayBuffer:
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
"""
if capacity <= 0:
raise ValueError("Capacity must be greater than 0.")
self.capacity = capacity
self.device = device
self.storage_device = storage_device
@@ -215,6 +223,8 @@ class ReplayBuffer:
# If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else []
self.image_augmentation_function = image_augmentation_function
if image_augmentation_function is None:
base_function = functools.partial(random_shift, pad=4)
self.image_augmentation_function = torch.compile(base_function)
@@ -418,11 +428,8 @@ class ReplayBuffer:
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
# Yield all items from the iterator
try:
with suppress(StopIteration):
yield from iterator
except StopIteration:
# Just continue the outer loop to create a new iterator
pass
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
"""
@@ -586,10 +593,7 @@ class ReplayBuffer:
action = data["action"]
if action_mask is not None:
if action.dim() == 1:
action = action[action_mask]
else:
action = action[:, action_mask]
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
replay_buffer.add(
state=data["state"],
@@ -925,95 +929,3 @@ def concatenate_batch_transitions(
left_info[key] = right_info[key]
return left_batch_transitions
if __name__ == "__main__":

    def test_load_dataset_with_complementary_info():
        """
        Manual round-trip check: LeRobotDataset -> ReplayBuffer -> LeRobotDataset.

        Loads 'aractingi/pick_lift_cube_two_cameras_gripper_penalty', builds a
        ReplayBuffer from it, samples one batch, then converts the buffer back into
        a dataset. Progress, timings and the `complementary_info` keys seen at each
        stage are printed so the result can be inspected by eye; nothing is asserted.

        Returns:
            The tuple (buffer, dataset, new_dataset).
        """
        import time

        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

        print("Loading dataset with complementary info...")

        # Only the repo id is supplied to the dataset constructor.
        dataset = LeRobotDataset(
            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
        )

        print(f"Dataset loaded with {len(dataset)} frames")
        print(f"Dataset features: {list(dataset.features.keys())}")

        # Inspect the first frame for complementary_info entries.
        sample = dataset[0]
        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
        print(f"Complementary info keys: {complementary_info_keys}")

        if "complementary_info.gripper_penalty" in sample:
            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")

        # Every key of the sample that begins with "observation" is used as a state key.
        state_keys = [key for key in sample if key.startswith("observation")]
        print(f"Using state keys: {state_keys}")

        # Build the replay buffer and time the conversion.
        began = time.time()
        buffer = ReplayBuffer.from_lerobot_dataset(
            lerobot_dataset=dataset,
            state_keys=state_keys,
            use_drq=True,
            optimize_memory=False,
        )
        load_time = time.time() - began
        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
        print(f"Buffer size: {len(buffer)}")

        # Draw one small batch and report what complementary_info it carries.
        print("Sampling from buffer to check complementary_info...")
        sampled = buffer.sample(batch_size=4)

        if sampled["complementary_info"] is None:
            print("No complementary_info found in batch")
        else:
            print("Complementary info in batch:")
            for key, value in sampled["complementary_info"].items():
                print(f" {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
                if key == "gripper_penalty":
                    print(f" Sample gripper_penalty values: {value[:5]}")

        # Convert the buffer back to a dataset and time that step as well.
        print("\nConverting buffer back to LeRobotDataset...")
        began = time.time()
        new_dataset = buffer.to_lerobot_dataset(
            repo_id="test_dataset_from_buffer",
            fps=dataset.fps,
            root="./test_dataset_from_buffer",
            task_name="test_conversion",
        )
        convert_time = time.time() - began
        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
        print(f"New dataset size: {len(new_dataset)} frames")

        # Inspect the first frame of the regenerated dataset the same way.
        new_sample = new_dataset[0]
        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
        print(f"New dataset complementary info keys: {new_complementary_info_keys}")

        if "complementary_info.gripper_penalty" in new_sample:
            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")

        # Side-by-side summary of the original and regenerated datasets.
        print("\nComparing original and new datasets:")
        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
        print(f"Original features: {list(dataset.features.keys())}")
        print(f"New features: {list(new_dataset.features.keys())}")

        return buffer, dataset, new_dataset

    # Run the test
    test_load_dataset_with_complementary_info()