[HIL-SERL PORT] Unit tests for Replay Buffer (#966)
@@ -15,7 +15,8 @@
 # limitations under the License.
 import functools
 import io
-import pickle
+import pickle  # nosec B403: Safe usage of pickle
+from contextlib import suppress
 from typing import Any, Callable, Optional, Sequence, TypedDict
 
 import torch
@@ -113,7 +114,7 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
 def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
     buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return torch.load(buffer)
+    return torch.load(buffer)  # nosec B614: Safe usage of torch.load
 
 
 def python_object_to_bytes(python_object: Any) -> bytes:
@@ -123,13 +124,17 @@ def python_object_to_bytes(python_object: Any) -> bytes:
 def bytes_to_python_object(buffer: bytes) -> Any:
     buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return pickle.load(buffer)
+    obj = pickle.load(buffer)  # nosec B301: Safe usage of pickle.load
+    # Add validation checks here
+    return obj
 
 
 def bytes_to_transitions(buffer: bytes) -> list[Transition]:
     buffer = io.BytesIO(buffer)
     buffer.seek(0)
-    return torch.load(buffer)
+    transitions = torch.load(buffer)  # nosec B614: Safe usage of torch.load
+    # Add validation checks here
+    return transitions
 
 
 def transitions_to_bytes(transitions: list[Transition]) -> bytes:
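
For context: these helpers (de)serialize trusted, locally produced bytes, which is what the nosec markers assert. A round-trip sketch, assuming the helpers are imported from lerobot.scripts.server.buffer (the module this diff touches):

    # Sketch, not part of the diff: round-tripping a trusted payload.
    from lerobot.scripts.server.buffer import bytes_to_python_object, python_object_to_bytes

    payload = {"step": 3, "note": "hello"}
    blob = python_object_to_bytes(payload)
    assert bytes_to_python_object(blob) == payload  # safe: we produced these bytes ourselves
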
@@ -201,6 +206,9 @@ class ReplayBuffer:
         optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when
             they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1].
         """
+        if capacity <= 0:
+            raise ValueError("Capacity must be greater than 0.")
+
         self.capacity = capacity
         self.device = device
         self.storage_device = storage_device
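
A toy illustration (an assumption about the scheme, not the buffer's actual code) of the derivation the docstring describes: with optimize_memory=True, next states need not be stored because next_state[i] == state[i + 1], and a terminal frame keeps its own state as its next state, matching what test_from_lerobot_dataset in the new test file expects:

    states = [0, 1, 2, 3]
    derived_next = [states[min(i + 1, len(states) - 1)] for i in range(len(states))]
    assert derived_next == [1, 2, 3, 3]
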
@@ -215,6 +223,8 @@ class ReplayBuffer:
         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
+
         self.image_augmentation_function = image_augmentation_function
+
         if image_augmentation_function is None:
             base_function = functools.partial(random_shift, pad=4)
             self.image_augmentation_function = torch.compile(base_function)
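
When no augmentation function is supplied, the buffer compiles random_shift with pad=4. random_shift itself is not shown in this diff; a typical DrQ-style random shift looks like the sketch below (drq_random_shift is a hypothetical stand-in that, for brevity, shifts the whole batch by a single random offset):

    import torch
    import torch.nn.functional as F

    def drq_random_shift(images: torch.Tensor, pad: int = 4) -> torch.Tensor:
        # Replicate-pad by `pad` pixels, then randomly crop back to the original size.
        n, c, h, w = images.shape
        padded = F.pad(images, (pad, pad, pad, pad), mode="replicate")
        top = int(torch.randint(0, 2 * pad + 1, (1,)))
        left = int(torch.randint(0, 2 * pad + 1, (1,)))
        return padded[:, :, top : top + h, left : left + w]
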
@@ -418,11 +428,8 @@ class ReplayBuffer:
             iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size)
 
             # Yield all items from the iterator
-            try:
+            with suppress(StopIteration):
                 yield from iterator
-            except StopIteration:
-                # Just continue the outer loop to create a new iterator
-                pass
 
     def _get_async_iterator(self, batch_size: int, queue_size: int = 2):
         """
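
Swapping try/except StopIteration/pass for contextlib.suppress is behavior-preserving; both swallow the exception and fall through. A standalone illustration:

    from contextlib import suppress

    it = iter([1, 2])
    with suppress(StopIteration):
        while True:
            print(next(it))  # prints 1, then 2; the final StopIteration is swallowed
    print("execution continues here")
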
@@ -586,10 +593,7 @@ class ReplayBuffer:
 
             action = data["action"]
             if action_mask is not None:
-                if action.dim() == 1:
-                    action = action[action_mask]
-                else:
-                    action = action[:, action_mask]
+                action = action[action_mask] if action.dim() == 1 else action[:, action_mask]
 
             replay_buffer.add(
                 state=data["state"],
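
The collapsed ternary keeps both indexing paths: a boolean mask over the only dimension of a single action vector, or over dim 1 of a batch of actions. A small standalone check (the tensors are made-up examples):

    import torch

    action_mask = torch.tensor([True, False, True, True])
    single = torch.tensor([0.0, 1.0, 2.0, 3.0])  # unbatched, shape (4,)
    batched = torch.arange(8.0).reshape(2, 4)    # batched, shape (2, 4)

    masked = single[action_mask] if single.dim() == 1 else single[:, action_mask]
    assert torch.equal(masked, torch.tensor([0.0, 2.0, 3.0]))
    assert batched[:, action_mask].shape == (2, 3)
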
@@ -925,95 +929,3 @@ def concatenate_batch_transitions(
                 left_info[key] = right_info[key]
 
     return left_batch_transitions
-
-
-if __name__ == "__main__":
-
-    def test_load_dataset_with_complementary_info():
-        """
-        Test loading a dataset with complementary_info into a ReplayBuffer.
-        The dataset 'aractingi/pick_lift_cube_two_cameras_gripper_penalty' contains
-        gripper_penalty values in complementary_info.
-        """
-        import time
-
-        from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-
-        print("Loading dataset with complementary info...")
-        # Load a small subset of the dataset (first episode)
-        dataset = LeRobotDataset(
-            repo_id="aractingi/pick_lift_cube_two_cameras_gripper_penalty",
-        )
-
-        print(f"Dataset loaded with {len(dataset)} frames")
-        print(f"Dataset features: {list(dataset.features.keys())}")
-
-        # Check if dataset has complementary_info.gripper_penalty
-        sample = dataset[0]
-        complementary_info_keys = [key for key in sample if key.startswith("complementary_info")]
-        print(f"Complementary info keys: {complementary_info_keys}")
-
-        if "complementary_info.gripper_penalty" in sample:
-            print(f"Found gripper_penalty: {sample['complementary_info.gripper_penalty']}")
-
-        # Extract state keys for the buffer
-        state_keys = []
-        for key in sample:
-            if key.startswith("observation"):
-                state_keys.append(key)
-
-        print(f"Using state keys: {state_keys}")
-
-        # Create a replay buffer from the dataset
-        start_time = time.time()
-        buffer = ReplayBuffer.from_lerobot_dataset(
-            lerobot_dataset=dataset, state_keys=state_keys, use_drq=True, optimize_memory=False
-        )
-        load_time = time.time() - start_time
-        print(f"Loaded dataset into buffer in {load_time:.2f} seconds")
-        print(f"Buffer size: {len(buffer)}")
-
-        # Check if complementary_info was transferred correctly
-        print("Sampling from buffer to check complementary_info...")
-        batch = buffer.sample(batch_size=4)
-
-        if batch["complementary_info"] is not None:
-            print("Complementary info in batch:")
-            for key, value in batch["complementary_info"].items():
-                print(f"  {key}: {type(value)}, shape: {value.shape if hasattr(value, 'shape') else 'N/A'}")
-                if key == "gripper_penalty":
-                    print(f"  Sample gripper_penalty values: {value[:5]}")
-        else:
-            print("No complementary_info found in batch")
-
-        # Now convert the buffer back to a LeRobotDataset
-        print("\nConverting buffer back to LeRobotDataset...")
-        start_time = time.time()
-        new_dataset = buffer.to_lerobot_dataset(
-            repo_id="test_dataset_from_buffer",
-            fps=dataset.fps,
-            root="./test_dataset_from_buffer",
-            task_name="test_conversion",
-        )
-        convert_time = time.time() - start_time
-        print(f"Converted buffer to dataset in {convert_time:.2f} seconds")
-        print(f"New dataset size: {len(new_dataset)} frames")
-
-        # Check if complementary_info was preserved
-        new_sample = new_dataset[0]
-        new_complementary_info_keys = [key for key in new_sample if key.startswith("complementary_info")]
-        print(f"New dataset complementary info keys: {new_complementary_info_keys}")
-
-        if "complementary_info.gripper_penalty" in new_sample:
-            print(f"Found gripper_penalty in new dataset: {new_sample['complementary_info.gripper_penalty']}")
-
-        # Compare original and new datasets
-        print("\nComparing original and new datasets:")
-        print(f"Original dataset frames: {len(dataset)}, New dataset frames: {len(new_dataset)}")
-        print(f"Original features: {list(dataset.features.keys())}")
-        print(f"New features: {list(new_dataset.features.keys())}")
-
-        return buffer, dataset, new_dataset
-
-    # Run the test
-    test_load_dataset_with_complementary_info()

tests/server/test_replay_buffer.py (new file, 598 lines)
@@ -0,0 +1,598 @@
import sys
from typing import Callable, Optional

import pytest
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.server.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from tests.fixtures.constants import DUMMY_REPO_ID


def state_dims() -> list[str]:
    return ["observation.image", "observation.state"]


@pytest.fixture
def replay_buffer() -> ReplayBuffer:
    return create_empty_replay_buffer()


def clone_state(state: dict) -> dict:
    return {k: v.clone() for k, v in state.items()}


def create_empty_replay_buffer(
    optimize_memory: bool = False,
    use_drq: bool = False,
    image_augmentation_function: Optional[Callable] = None,
) -> ReplayBuffer:
    buffer_capacity = 10
    device = "cpu"
    return ReplayBuffer(
        buffer_capacity,
        device,
        state_dims(),
        optimize_memory=optimize_memory,
        use_drq=use_drq,
        image_augmentation_function=image_augmentation_function,
    )


def create_random_image() -> torch.Tensor:
    return torch.rand(3, 84, 84)


def create_dummy_transition() -> dict:
    return {
        "observation.image": create_random_image(),
        "action": torch.randn(4),
        "reward": torch.tensor(1.0),
        "observation.state": torch.randn(10),
        "done": torch.tensor(False),
        "truncated": torch.tensor(False),
        "complementary_info": {},
    }


def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    root = tmp_path / "test"
    return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)


def create_dummy_state() -> dict:
    return {
        "observation.image": create_random_image(),
        "observation.state": torch.randn(10),
    }


def get_tensor_memory_consumption(tensor):
    return tensor.nelement() * tensor.element_size()


def get_tensors_memory_consumption(obj, visited_addresses):
    total_size = 0

    address = id(obj)
    if address in visited_addresses:
        return 0

    visited_addresses.add(address)

    if isinstance(obj, torch.Tensor):
        return get_tensor_memory_consumption(obj)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            total_size += get_tensors_memory_consumption(item, visited_addresses)
    elif isinstance(obj, dict):
        for value in obj.values():
            total_size += get_tensors_memory_consumption(value, visited_addresses)
    elif hasattr(obj, "__dict__"):
        # It's an object; recurse into its attributes
        for _, attr in vars(obj).items():
            total_size += get_tensors_memory_consumption(attr, visited_addresses)

    return total_size


def get_object_memory(obj):
    # Track visited addresses to avoid infinite loops
    # and double-counting when two properties point to the same object
    visited_addresses = set()

    # Get the size of the object in bytes
    total_size = sys.getsizeof(obj)

    # Get the size of the tensor attributes
    total_size += get_tensors_memory_consumption(obj, visited_addresses)

    return total_size


def create_dummy_action() -> torch.Tensor:
    return torch.randn(4)


def dict_properties() -> list:
    return ["state", "next_state"]


@pytest.fixture
def dummy_state() -> dict:
    return create_dummy_state()


@pytest.fixture
def next_dummy_state() -> dict:
    return create_dummy_state()


@pytest.fixture
def dummy_action() -> torch.Tensor:
    return torch.randn(4)


def test_empty_buffer_sample_raises_error(replay_buffer):
    assert len(replay_buffer) == 0, "Replay buffer should be empty."
    assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)


def test_zero_capacity_buffer_raises_error():
    with pytest.raises(ValueError, match="Capacity must be greater than 0."):
        ReplayBuffer(0, "cpu", ["observation", "next_observation"])


def test_add_transition(replay_buffer, dummy_state, dummy_action):
    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
    assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
    assert torch.equal(replay_buffer.actions[0], dummy_action), (
        "Action should be equal to the first transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
    assert not replay_buffer.dones[0], "Done should be False for the first transition."
    assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."

    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
            "Observation should be equal to the first transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
            "Next observation should be equal to the first transition."
        )


def test_add_over_capacity():
    replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)

    assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."

    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
            "Observation should be equal to the last transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
            "Next observation should be equal to the last transition."
        )

    assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
        "Action should be equal to the last transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
    assert replay_buffer.dones[0], "Done should be True for the last transition."
    assert replay_buffer.truncateds[0], "Truncated should be True for the last transition."


def test_sample_from_empty_buffer(replay_buffer):
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)


def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(1)

    expected_batch_transition = BatchTransition(
        state=clone_state(dummy_state),
        action=dummy_action.clone(),
        reward=1.0,
        next_state=clone_state(next_dummy_state),
        done=False,
        truncated=False,
    )

    for buffer_property in dict_properties():
        for k, v in expected_batch_transition[buffer_property].items():
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."
            assert got_state.device.type == "cpu", f"{k} should be on cpu."

            assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."

    for key, _value in expected_batch_transition.items():
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]

        v_tensor = expected_batch_transition[key]
        if not isinstance(v_tensor, torch.Tensor):
            v_tensor = torch.tensor(v_tensor)

        assert got_value.shape[0] == 1, f"{key} should have 1 transition."
        assert got_value.device.type == "cpu", f"{key} should be on cpu."
        assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."


def test_sample_with_batch_bigger_than_buffer_size(
    replay_buffer, dummy_state, next_dummy_state, dummy_action
):
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(10)

    expected_batch_transition = BatchTransition(
        state=dummy_state,
        action=dummy_action,
        reward=1.0,
        next_state=next_dummy_state,
        done=False,
        truncated=False,
    )

    for buffer_property in dict_properties():
        for k in expected_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."

    for key in expected_batch_transition:
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]
        assert got_value.shape[0] == 1, f"{key} should have 1 transition."


def test_sample_batch(replay_buffer):
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)

    dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
    dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]

    got_batch_transition = replay_buffer.sample(3)

    for buffer_property in dict_properties():
        for k in got_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 3, f"{k} should have 3 transitions."

            for got_state_item in got_state:
                assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
                    f"{k} should be equal to one of the dummy states."
                )

    for got_action_item in got_batch_transition["action"]:
        assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
            "Actions should be equal to the dummy actions."
        )

    for k in got_batch_transition:
        if k in dict_properties() or k == "complementary_info":
            continue

        got_value = got_batch_transition[k]
        assert got_value.shape[0] == 3, f"{k} should have 3 transitions."


def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
    with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
        replay_buffer.to_lerobot_dataset("dummy_repo")


def test_to_lerobot_dataset(tmp_path):
    ds, buffer = create_dataset_from_replay_buffer(tmp_path)

    assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
    assert ds.fps == 1, "FPS should be 1"
    assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"

    for dim in state_dims():
        assert dim in ds.features
        assert ds.features[dim]["shape"] == buffer.states[dim][0].shape

    assert ds.num_episodes == 2
    assert ds.num_frames == 4

    for i in range(len(ds)):
        for feature, value in ds[i].items():
            if feature == "action":
                assert torch.equal(value, buffer.actions[i])
            elif feature == "next.reward":
                assert torch.equal(value, buffer.rewards[i])
            elif feature == "next.done":
                assert torch.equal(value, buffer.dones[i])
            elif feature == "observation.image":
                # Tensor -> numpy conversion is not exact, so allow some difference
                # TODO: Check and fix it
                torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
            elif feature == "observation.state":
                assert torch.equal(value, buffer.states["observation.state"][i])


def test_from_lerobot_dataset(tmp_path):
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    root = tmp_path / "test"
    ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)

    reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
        ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
    )

    assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original"

    assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), (
        "Actions from converted buffer should be equal to the original replay buffer."
    )
    assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), (
        "Rewards from converted buffer should be equal to the original replay buffer."
    )
    assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), (
        "Dones from converted buffer should be equal to the original replay buffer."
    )

    # LeRobot datasets don't support truncateds yet
    expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool()
    assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), (
        "Truncateds from converted buffer should all be False"
    )

    assert torch.equal(
        replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"]
    ), "State should be the same after converting to a dataset and back"

    for i in range(4):
        torch.testing.assert_close(
            replay_buffer.states["observation.image"][i],
            reconverted_buffer.states["observation.image"][i],
            rtol=0.4,
            atol=0.004,
        )

    # Frames 2 and 3 have the done flag set, so their next states equal the current state
    for i in range(2):
        # In the current implementation we take the next state from `states` and ignore `next_states`
        next_index = (i + 1) % 4

        torch.testing.assert_close(
            replay_buffer.states["observation.image"][next_index],
            reconverted_buffer.next_states["observation.image"][i],
            rtol=0.4,
            atol=0.004,
        )

    for i in range(2, 4):
        assert torch.equal(
            replay_buffer.states["observation.state"][i],
            reconverted_buffer.next_states["observation.state"][i],
        )


def test_buffer_sample_alignment():
    # Initialize buffer
    buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")

    # Fill buffer with patterned data
    for i in range(100):
        signature = float(i) / 100.0
        state = {"state_value": torch.tensor([[signature]]).float()}
        action = torch.tensor([[2.0 * signature]]).float()
        reward = 3.0 * signature

        is_end = (i + 1) % 10 == 0
        if is_end:
            next_state = {"state_value": torch.tensor([[signature]]).float()}
            done = True
        else:
            next_signature = float(i + 1) / 100.0
            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
            done = False

        buffer.add(state, action, reward, next_state, done, False)

    # Sample and verify
    batch = buffer.sample(50)

    for i in range(50):
        state_sig = batch["state"]["state_value"][i].item()
        action_val = batch["action"][i].item()
        reward_val = batch["reward"][i].item()
        next_state_sig = batch["next_state"]["state_value"][i].item()
        is_done = batch["done"][i].item() > 0.5

        # Verify relationships
        assert abs(action_val - 2.0 * state_sig) < 1e-4, (
            f"Action {action_val} should be 2x state signature {state_sig}"
        )

        assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
            f"Reward {reward_val} should be 3x state signature {state_sig}"
        )

        if is_done:
            assert abs(next_state_sig - state_sig) < 1e-4, (
                f"For done states, next_state {next_state_sig} should equal state {state_sig}"
            )
        else:
            # Either it's the next sequential state (+0.01) or same state (for episode boundaries)
            valid_next = (
                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
            )
            assert valid_next, (
                f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
            )


def test_memory_optimization():
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    optimized_replay_buffer = create_empty_replay_buffer(True)
    optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
    optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
    optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
    optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)

    assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
        "Optimized replay buffer should be smaller than the original replay buffer"
    )


def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
    def dummy_image_augmentation_function(x):
        return torch.ones_like(x) * 10

    replay_buffer = create_empty_replay_buffer(
        use_drq=True, image_augmentation_function=dummy_image_augmentation_function
    )

    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    sampled_transitions = replay_buffer.sample(1)
    assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
        "Image augmentations should be applied"
    )
    assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
        "Image augmentations should be applied"
    )


def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
    dummy_state, dummy_action
):
    replay_buffer = create_empty_replay_buffer(use_drq=True)

    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    # Let's check that it doesn't fail and shapes are correct
    sampled_transitions = replay_buffer.sample(1)
    assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
    assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)


def test_random_crop_vectorized_basic():
    # Create a batch of 2 images with known patterns
    batch_size, channels, height, width = 2, 3, 10, 8
    images = torch.zeros((batch_size, channels, height, width))

    # Fill with unique values for testing
    for b in range(batch_size):
        images[b] = b + 1

    crop_size = (6, 4)  # Smaller than original
    cropped = random_crop_vectorized(images, crop_size)

    # Check output shape
    assert cropped.shape == (batch_size, channels, *crop_size)

    # Check that values are preserved (should be either 1s or 2s for respective batches)
    assert torch.all(cropped[0] == 1)
    assert torch.all(cropped[1] == 2)


def test_random_crop_vectorized_invalid_size():
    images = torch.zeros((2, 3, 10, 8))

    # Test crop size larger than image
    with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
        random_crop_vectorized(images, (12, 8))

    with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
        random_crop_vectorized(images, (10, 10))
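
A quick sanity check of the memory-accounting helpers used by test_memory_optimization (a sketch; the exact total also includes Python object overhead from sys.getsizeof):

    import torch

    t = torch.rand(3, 84, 84)  # float32: 3 * 84 * 84 * 4 = 84672 bytes of tensor data
    assert get_tensor_memory_consumption(t) == 84672
    assert get_object_memory({"img": t}) >= 84672  # the dict adds its own overhead

The suite runs with: pytest tests/server/test_replay_buffer.py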