Fix test comparing uninitialized array segment
The test was inadvertently comparing uninitialized parts of the array, which could lead to inconsistent or undefined results. This fix ensures only the relevant, properly initialized sections are checked. Co-authored-by: Eugene Mironov <helper2424@gmail.com>
This commit is contained in:
@@ -401,26 +401,27 @@ def test_from_lerobot_dataset(tmp_path):
|
|||||||
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original"
|
# Check only the part of the buffer that's actually filled with data
|
||||||
|
assert torch.equal(
|
||||||
assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), (
|
reconverted_buffer.actions[: len(replay_buffer)],
|
||||||
"Actions from converted buffer should be equal to the original replay buffer."
|
replay_buffer.actions[: len(replay_buffer)],
|
||||||
)
|
), "Actions from converted buffer should be equal to the original replay buffer."
|
||||||
assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), (
|
assert torch.equal(
|
||||||
"Rewards from converted buffer should be equal to the original replay buffer."
|
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||||
)
|
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||||
assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), (
|
assert torch.equal(
|
||||||
"Dones from converted buffer should be equal to the original replay buffer."
|
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||||
)
|
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||||
|
|
||||||
# Lerobot DS haven't supported truncateds yet
|
# Lerobot DS haven't supported truncateds yet
|
||||||
expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool()
|
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||||
assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), (
|
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||||
"Truncateds from converted buffer should be equal False"
|
"Truncateds from converted buffer should be equal False"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"]
|
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||||
|
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||||
), "State should be the same after converting to dataset and return back"
|
), "State should be the same after converting to dataset and return back"
|
||||||
|
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
|
|||||||
Reference in New Issue
Block a user