From 5231752487270339e56056196f6da1c6d5b180d6 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 22 Apr 2025 15:13:10 +0200 Subject: [PATCH] Fix test comparing uninitialized array segment The test was inadvertently comparing uninitialized parts of the array, which could lead to inconsistent or undefined results. This fix ensures only the relevant, properly initialized sections are checked. Co-authored-by: Eugene Mironov --- tests/server/test_replay_buffer.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/server/test_replay_buffer.py b/tests/server/test_replay_buffer.py index 55843de8..5d1cd62f 100644 --- a/tests/server/test_replay_buffer.py +++ b/tests/server/test_replay_buffer.py @@ -401,26 +401,27 @@ def test_from_lerobot_dataset(tmp_path): ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False ) - assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original" - - assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), ( - "Actions from converted buffer should be equal to the original replay buffer." - ) - assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), ( - "Rewards from converted buffer should be equal to the original replay buffer." - ) - assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), ( - "Dones from converted buffer should be equal to the original replay buffer." - ) + # Check only the part of the buffer that's actually filled with data + assert torch.equal( + reconverted_buffer.actions[: len(replay_buffer)], + replay_buffer.actions[: len(replay_buffer)], + ), "Actions from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] + ), "Rewards from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] + ), "Dones from converted buffer should be equal to the original replay buffer." # Lerobot DS haven't supported truncateds yet - expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool() - assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), ( + expected_truncateds = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( "Truncateds from converted buffer should be equal False" ) assert torch.equal( - replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"] + replay_buffer.states["observation.state"][: len(replay_buffer)], + reconverted_buffer.states["observation.state"][: len(replay_buffer)], ), "State should be the same after converting to dataset and return back" for i in range(4):