diff --git a/tests/server/test_replay_buffer.py b/tests/server/test_replay_buffer.py index 55843de8..5d1cd62f 100644 --- a/tests/server/test_replay_buffer.py +++ b/tests/server/test_replay_buffer.py @@ -401,26 +401,27 @@ def test_from_lerobot_dataset(tmp_path): ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False ) - assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original" - - assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), ( - "Actions from converted buffer should be equal to the original replay buffer." - ) - assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), ( - "Rewards from converted buffer should be equal to the original replay buffer." - ) - assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), ( - "Dones from converted buffer should be equal to the original replay buffer." - ) + # Check only the part of the buffer that's actually filled with data + assert torch.equal( + reconverted_buffer.actions[: len(replay_buffer)], + replay_buffer.actions[: len(replay_buffer)], + ), "Actions from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] + ), "Rewards from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] + ), "Dones from converted buffer should be equal to the original replay buffer." # Lerobot DS haven't supported truncateds yet - expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool() - assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), ( + expected_truncateds = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( "Truncateds from converted buffer should be equal False" ) assert torch.equal( - replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"] + replay_buffer.states["observation.state"][: len(replay_buffer)], + reconverted_buffer.states["observation.state"][: len(replay_buffer)], ), "State should be the same after converting to dataset and return back" for i in range(4):