[HIL-SERL PORT] Unit tests for Replay Buffer (#966)
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import io
|
||||
import pickle
|
||||
import pickle # nosec B403: Safe usage of pickle
|
||||
from contextlib import suppress
|
||||
from typing import Any, Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import torch
|
||||
@@ -113,7 +114,7 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
|
||||
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
    """Deserialize a byte string into a state dict using ``torch.load``.

    Args:
        buffer: Raw bytes containing a payload readable by ``torch.load``
            (presumably produced by ``state_to_bytes`` -- confirm with callers).

    Returns:
        The object reconstructed by ``torch.load``. The annotation promises a
        mapping from names to tensors, but no check is performed here.
    """
    # SECURITY: torch.load unpickles its input. Only pass bytes that come from
    # a trusted producer; the ``nosec`` marker below silences Bandit, it does
    # not make the call safe for untrusted data.
    # A freshly created BytesIO is already positioned at offset 0, so no
    # explicit seek is required before reading.
    stream = io.BytesIO(buffer)
    return torch.load(stream)  # nosec B614: Safe usage of torch.load
|
||||
|
||||
|
||||
def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
@@ -123,13 +124,17 @@ def python_object_to_bytes(python_object: Any) -> bytes:
|
||||
def bytes_to_python_object(buffer: bytes) -> Any:
    """Unpickle a single Python object from a byte string.

    Args:
        buffer: Raw bytes containing one pickled object (presumably produced
            by ``python_object_to_bytes`` -- confirm with callers).

    Returns:
        The object reconstructed by ``pickle.load``.
    """
    # SECURITY: unpickling can execute arbitrary code. Only pass bytes that
    # come from a trusted producer; the ``nosec`` marker below silences Bandit,
    # it does not make the call safe for untrusted data.
    # A freshly created BytesIO is already positioned at offset 0, so no
    # explicit seek is required before reading.
    stream = io.BytesIO(buffer)
    obj = pickle.load(stream)  # nosec B301: Safe usage of pickle.load
    # TODO: add validation checks on ``obj`` here before returning it.
    return obj
|
||||
|
||||
|
||||
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
    """Deserialize a byte string into a list of transitions using ``torch.load``.

    Args:
        buffer: Raw bytes containing a payload readable by ``torch.load``
            (presumably produced by ``transitions_to_bytes`` -- confirm with
            callers).

    Returns:
        The object reconstructed by ``torch.load``. The annotation promises a
        list of ``Transition`` items, but no check is performed here.
    """
    # SECURITY: torch.load unpickles its input. Only pass bytes that come from
    # a trusted producer; the ``nosec`` marker below silences Bandit, it does
    # not make the call safe for untrusted data.
    # A freshly created BytesIO is already positioned at offset 0, so no
    # explicit seek is required before reading.
    stream = io.BytesIO(buffer)
    transitions = torch.load(stream)  # nosec B614: Safe usage of torch.load
    # TODO: add validation checks on ``transitions`` here before returning.
    return transitions
|
||||
|
||||
|
||||
def transitions_to_bytes(transitions: list[Transition]) -> bytes:
|
||||
@@ -201,6 +206,9 @@ class ReplayBuffer:
|
||||
optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
|
||||
they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
|
||||
"""
|
||||
if capacity <= 0:
|
||||
raise ValueError("Capacity must be greater than 0.")
|
||||
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.storage_device = storage_device
|
||||
@@ -215,6 +223,8 @@ class ReplayBuffer:
|
||||
# If no state_keys provided, default to an empty list
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
|
||||
self.image_augmentation_function = image_augmentation_function
|
||||
|
||||
if image_augmentation_function is None:
|
||||
base_function = functools.partial(random_shift, pad=4)
|
||||
self.image_augmentation_function = torch.compile(base_function)
|
||||
@@ -418,11 +428,8 @@ class ReplayBuffer:
|
||||
iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
|
||||
|
||||
# Yield all items from the iterator
|
||||
try:
|
||||
with suppress(StopIteration):
|
||||
yield from iterator
|
||||
except StopIteration:
|
||||
# Just continue the outer loop to create a new iterator
|
||||
pass
|
||||
|
||||
def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
|
||||
"""
|
||||
@@ -586,10 +593,7 @@ class ReplayBuffer:
|
||||
|
||||
action = data["action"]
|
||||
if action_mask is not None:
|
||||
if action.dim() == 1:
|
||||
action = action[action_mask]
|
||||
else:
|
||||
action = action[:, action_mask]
|
||||
action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
@@ -925,95 +929,3 @@ def concatenate_batch_transitions(
|
||||
left_info[key] = right_info[key]
|
||||
|
||||
return left_batch_transitions
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def test_load_dataset_with_complementary_info():
        """
        Manual round-trip check: LeRobotDataset -> ReplayBuffer -> LeRobotDataset.

        The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
        gripper_penalty values in complementary_info. Progress and results are
        printed; nothing is asserted.

        Returns:
            Tuple of (replay buffer, original dataset, dataset rebuilt from the buffer).
        """
        import time

        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

        print("Loading dataset with complementary info...")
        # Only `repo_id` is passed, so no episode/subset selection is requested here.
        dataset = LeRobotDataset(
            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
        )

        print(f"Dataset loaded with {len(dataset)} frames")
        print(f"Dataset features: {list(dataset.features.keys())}")

        # Check if dataset has complementary_info.gripper_penalty
        sample = dataset[0]
        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
        print(f"Complementary info keys: {complementary_info_keys}")

        if "complementary_info.gripper_penalty" in sample:
            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")

        # Every sample key starting with "observation" is used as a state key
        state_keys = [key for key in sample if key.startswith("observation")]
        print(f"Using state keys: {state_keys}")

        # Create a replay buffer from the dataset.
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        buffer = ReplayBuffer.from_lerobot_dataset(
            lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
        )
        load_time = time.perf_counter() - start_time
        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
        print(f"Buffer size: {len(buffer)}")

        # Check if complementary_info was transferred correctly
        print("Sampling from buffer to check complementary_info...")
        batch = buffer.sample(batch_size=4)

        if batch["complementary_info"] is not None:
            print("Complementary info in batch:")
            for key, value in batch["complementary_info"].items():
                print(f" {key}: {type(value)}, shape: {getattr(value, 'shape', 'N/A')}")
                if key == "gripper_penalty":
                    print(f" Sample gripper_penalty values: {value[:5]}")
        else:
            print("No complementary_info found in batch")

        # Now convert the buffer back to a LeRobotDataset
        print("\nConverting buffer back to LeRobotDataset...")
        start_time = time.perf_counter()
        new_dataset = buffer.to_lerobot_dataset(
            repo_id="test_dataset_from_buffer",
            fps=dataset.fps,
            root="./test_dataset_from_buffer",
            task_name="test_conversion",
        )
        convert_time = time.perf_counter() - start_time
        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
        print(f"New dataset size: {len(new_dataset)} frames")

        # Check if complementary_info was preserved
        new_sample = new_dataset[0]
        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
        print(f"New dataset complementary info keys: {new_complementary_info_keys}")

        if "complementary_info.gripper_penalty" in new_sample:
            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")

        # Compare original and new datasets
        print("\nComparing original and new datasets:")
        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
        print(f"Original features: {list(dataset.features.keys())}")
        print(f"New features: {list(new_dataset.features.keys())}")

        return buffer, dataset, new_dataset

    # Run the test
    test_load_dataset_with_complementary_info()
|
||||
|
||||
Reference in New Issue
Block a user