chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -22,11 +22,12 @@ import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
from tests.fixtures.constants import DUMMY_REPO_ID
def state_dims() -> list[str]:
return ["observation.image", "observation.state"]
return [OBS_IMAGE, OBS_STATE]
@pytest.fixture
@@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor:
def create_dummy_transition() -> dict:
return {
"observation.image": create_random_image(),
OBS_IMAGE: create_random_image(),
"action": torch.randn(4),
"reward": torch.tensor(1.0),
"observation.state": torch.randn(
OBS_STATE: torch.randn(
10,
),
"done": torch.tensor(False),
@@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB
def create_dummy_state() -> dict:
return {
"observation.image": create_random_image(),
"observation.state": torch.randn(
OBS_IMAGE: create_random_image(),
OBS_STATE: torch.randn(
10,
),
}
@@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer):
def test_zero_capacity_buffer_raises_error():
with pytest.raises(ValueError, match="Capacity must be greater than 0."):
ReplayBuffer(0, "cpu", ["observation", "next_observation"])
ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"])
def test_add_transition(replay_buffer, dummy_state, dummy_action):
@@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
def test_add_over_capacity():
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action()
@@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path):
assert ds.num_frames == 4
for j, value in enumerate(ds):
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j]))
for i in range(len(ds)):
for feature, value in ds[i].items():
@@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path):
assert torch.equal(value, buffer.rewards[i])
elif feature == "next.done":
assert torch.equal(value, buffer.dones[i])
elif feature == "observation.image":
elif feature == OBS_IMAGE:
# Tensor -> numpy is not precise, so we have some diff there
# TODO: Check and fix it
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
elif feature == "observation.state":
assert torch.equal(value, buffer.states["observation.state"][i])
torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003)
elif feature == OBS_STATE:
assert torch.equal(value, buffer.states[OBS_STATE][i])
def test_from_lerobot_dataset(tmp_path):
@@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path):
)
assert torch.equal(
replay_buffer.states["observation.state"][: len(replay_buffer)],
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
replay_buffer.states[OBS_STATE][: len(replay_buffer)],
reconverted_buffer.states[OBS_STATE][: len(replay_buffer)],
), "State should be the same after converting to dataset and return back"
for i in range(4):
torch.testing.assert_close(
replay_buffer.states["observation.image"][i],
reconverted_buffer.states["observation.image"][i],
replay_buffer.states[OBS_IMAGE][i],
reconverted_buffer.states[OBS_IMAGE][i],
rtol=0.4,
atol=0.004,
)
@@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path):
next_index = (i + 1) % 4
torch.testing.assert_close(
replay_buffer.states["observation.image"][next_index],
reconverted_buffer.next_states["observation.image"][i],
replay_buffer.states[OBS_IMAGE][next_index],
reconverted_buffer.next_states[OBS_IMAGE][i],
rtol=0.4,
atol=0.004,
)
for i in range(2, 4):
assert torch.equal(
replay_buffer.states["observation.state"][i],
reconverted_buffer.next_states["observation.state"][i],
replay_buffer.states[OBS_STATE][i],
reconverted_buffer.next_states[OBS_STATE][i],
)
@@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
sampled_transitions = replay_buffer.sample(1)
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
"Image augmentations should be applied"
)
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied"
assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), (
"Image augmentations should be applied"
)
@@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct
# Let's check that it doesn't fail and shapes are correct
sampled_transitions = replay_buffer.sample(1)
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84)
assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84)
def test_random_crop_vectorized_basic():
@@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
buffer = ReplayBuffer(
capacity=capacity,
device="cpu",
state_keys=["observation.image", "observation.state"],
state_keys=[OBS_IMAGE, OBS_STATE],
storage_device="cpu",
)
@@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
img = torch.ones(3, 128, 128) * i
state_vec = torch.arange(11).float() + i
state = {
"observation.image": img,
"observation.state": state_vec,
OBS_IMAGE: img,
OBS_STATE: state_vec,
}
buffer.add(
state=state,
@@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic():
iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
batch = next(iterator)
images = batch["state"]["observation.image"]
states = batch["state"]["observation.state"]
images = batch["state"][OBS_IMAGE]
states = batch["state"][OBS_STATE]
assert images.shape == (batch_size, 3, 128, 128)
assert states.shape == (batch_size, 11)
next_images = batch["next_state"]["observation.image"]
next_states = batch["next_state"]["observation.state"]
next_images = batch["next_state"][OBS_IMAGE]
next_states = batch["next_state"][OBS_STATE]
assert next_images.shape == (batch_size, 3, 128, 128)
assert next_states.shape == (batch_size, 11)
@@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations():
for _ in range(5):
batch = next(iterator)
images = batch["state"]["observation.image"]
states = batch["state"]["observation.state"]
images = batch["state"][OBS_IMAGE]
states = batch["state"][OBS_STATE]
assert images.shape == (batch_size, 3, 128, 128)
assert states.shape == (batch_size, 11)
next_images = batch["next_state"]["observation.image"]
next_states = batch["next_state"]["observation.state"]
next_images = batch["next_state"][OBS_IMAGE]
next_states = batch["next_state"][OBS_STATE]
assert next_images.shape == (batch_size, 3, 128, 128)
assert next_states.shape == (batch_size, 11)