Fix test comparing uninitialized array segment

The test was inadvertently comparing uninitialized parts of the array,
which could lead to inconsistent or undefined results. This fix ensures
only the relevant, properly initialized sections are checked.

Co-authored-by: Eugene Mironov <helper2424@gmail.com>
This commit is contained in:
AdilZouitine
2025-04-22 15:13:10 +02:00
parent 4ce3362724
commit 5231752487

View File

@@ -401,26 +401,27 @@ def test_from_lerobot_dataset(tmp_path):
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
)
assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original"
assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), (
"Actions from converted buffer should be equal to the original replay buffer."
)
assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), (
"Rewards from converted buffer should be equal to the original replay buffer."
)
assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), (
"Dones from converted buffer should be equal to the original replay buffer."
)
# Check only the part of the buffer that's actually filled with data
assert torch.equal(
reconverted_buffer.actions[: len(replay_buffer)],
replay_buffer.actions[: len(replay_buffer)],
), "Actions from converted buffer should be equal to the original replay buffer."
assert torch.equal(
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
), "Rewards from converted buffer should be equal to the original replay buffer."
assert torch.equal(
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
), "Dones from converted buffer should be equal to the original replay buffer."
# Lerobot DS haven't supported truncateds yet
expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool()
assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), (
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
"Truncateds from converted buffer should be equal False"
)
assert torch.equal(
replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"]
replay_buffer.states["observation.state"][: len(replay_buffer)],
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
), "State should be the same after converting to dataset and return back"
for i in range(4):