forked from tangger/lerobot
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -22,11 +22,12 @@ import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
|
||||
return ["observation.image", "observation.state"]
|
||||
return [OBS_IMAGE, OBS_STATE]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor:
|
||||
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
OBS_IMAGE: create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
"observation.state": torch.randn(
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
@@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB
|
||||
|
||||
def create_dummy_state() -> dict:
|
||||
return {
|
||||
"observation.image": create_random_image(),
|
||||
"observation.state": torch.randn(
|
||||
OBS_IMAGE: create_random_image(),
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
),
|
||||
}
|
||||
@@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer):
|
||||
|
||||
def test_zero_capacity_buffer_raises_error():
|
||||
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
|
||||
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
|
||||
ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"])
|
||||
|
||||
|
||||
def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
@@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
|
||||
|
||||
def test_add_over_capacity():
|
||||
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
|
||||
replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
|
||||
dummy_state_1 = create_dummy_state()
|
||||
dummy_action_1 = create_dummy_action()
|
||||
|
||||
@@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
assert ds.num_frames == 4
|
||||
|
||||
for j, value in enumerate(ds):
|
||||
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
|
||||
print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j]))
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
@@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == "observation.image":
|
||||
elif feature == OBS_IMAGE:
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||
elif feature == "observation.state":
|
||||
assert torch.equal(value, buffer.states["observation.state"][i])
|
||||
torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003)
|
||||
elif feature == OBS_STATE:
|
||||
assert torch.equal(value, buffer.states[OBS_STATE][i])
|
||||
|
||||
|
||||
def test_from_lerobot_dataset(tmp_path):
|
||||
@@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path):
|
||||
)
|
||||
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
replay_buffer.states[OBS_STATE][: len(replay_buffer)],
|
||||
reconverted_buffer.states[OBS_STATE][: len(replay_buffer)],
|
||||
), "State should be the same after converting to dataset and return back"
|
||||
|
||||
for i in range(4):
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][i],
|
||||
reconverted_buffer.states["observation.image"][i],
|
||||
replay_buffer.states[OBS_IMAGE][i],
|
||||
reconverted_buffer.states[OBS_IMAGE][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
@@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path):
|
||||
next_index = (i + 1) % 4
|
||||
|
||||
torch.testing.assert_close(
|
||||
replay_buffer.states["observation.image"][next_index],
|
||||
reconverted_buffer.next_states["observation.image"][i],
|
||||
replay_buffer.states[OBS_IMAGE][next_index],
|
||||
reconverted_buffer.next_states[OBS_IMAGE][i],
|
||||
rtol=0.4,
|
||||
atol=0.004,
|
||||
)
|
||||
|
||||
for i in range(2, 4):
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"][i],
|
||||
reconverted_buffer.next_states["observation.state"][i],
|
||||
replay_buffer.states[OBS_STATE][i],
|
||||
reconverted_buffer.next_states[OBS_STATE][i],
|
||||
)
|
||||
|
||||
|
||||
@@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio
|
||||
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
|
||||
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
|
||||
assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied"
|
||||
assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), (
|
||||
"Image augmentations should be applied"
|
||||
)
|
||||
|
||||
@@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct
|
||||
|
||||
# Let's check that it doesn't fail and shapes are correct
|
||||
sampled_transitions = replay_buffer.sample(1)
|
||||
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84)
|
||||
assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84)
|
||||
|
||||
|
||||
def test_random_crop_vectorized_basic():
|
||||
@@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
|
||||
buffer = ReplayBuffer(
|
||||
capacity=capacity,
|
||||
device="cpu",
|
||||
state_keys=["observation.image", "observation.state"],
|
||||
state_keys=[OBS_IMAGE, OBS_STATE],
|
||||
storage_device="cpu",
|
||||
)
|
||||
|
||||
@@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
|
||||
img = torch.ones(3, 128, 128) * i
|
||||
state_vec = torch.arange(11).float() + i
|
||||
state = {
|
||||
"observation.image": img,
|
||||
"observation.state": state_vec,
|
||||
OBS_IMAGE: img,
|
||||
OBS_STATE: state_vec,
|
||||
}
|
||||
buffer.add(
|
||||
state=state,
|
||||
@@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic():
|
||||
iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
|
||||
batch = next(iterator)
|
||||
|
||||
images = batch["state"]["observation.image"]
|
||||
states = batch["state"]["observation.state"]
|
||||
images = batch["state"][OBS_IMAGE]
|
||||
states = batch["state"][OBS_STATE]
|
||||
|
||||
assert images.shape == (batch_size, 3, 128, 128)
|
||||
assert states.shape == (batch_size, 11)
|
||||
|
||||
next_images = batch["next_state"]["observation.image"]
|
||||
next_states = batch["next_state"]["observation.state"]
|
||||
next_images = batch["next_state"][OBS_IMAGE]
|
||||
next_states = batch["next_state"][OBS_STATE]
|
||||
|
||||
assert next_images.shape == (batch_size, 3, 128, 128)
|
||||
assert next_states.shape == (batch_size, 11)
|
||||
@@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations():
|
||||
|
||||
for _ in range(5):
|
||||
batch = next(iterator)
|
||||
images = batch["state"]["observation.image"]
|
||||
states = batch["state"]["observation.state"]
|
||||
images = batch["state"][OBS_IMAGE]
|
||||
states = batch["state"][OBS_STATE]
|
||||
assert images.shape == (batch_size, 3, 128, 128)
|
||||
assert states.shape == (batch_size, 11)
|
||||
|
||||
next_images = batch["next_state"]["observation.image"]
|
||||
next_states = batch["next_state"]["observation.state"]
|
||||
next_images = batch["next_state"][OBS_IMAGE]
|
||||
next_states = batch["next_state"][OBS_STATE]
|
||||
assert next_images.shape == (batch_size, 3, 128, 128)
|
||||
assert next_states.shape == (batch_size, 11)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user