forked from tangger/lerobot
Fix test comparing uninitialized array segment
The test was inadvertently comparing uninitialized parts of the array, which could lead to inconsistent or undefined results. This fix ensures only the relevant, properly initialized sections are checked. Co-authored-by: Eugene Mironov <helper2424@gmail.com>
This commit is contained in:
@@ -401,26 +401,27 @@ def test_from_lerobot_dataset(tmp_path):
|
||||
ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
|
||||
)
|
||||
|
||||
assert len(reconverted_buffer) == 4, "Reconverted Replay buffer should have the same size as original"
|
||||
|
||||
assert torch.equal(reconverted_buffer.actions, replay_buffer.actions), (
|
||||
"Actions from converted buffer should be equal to the original replay buffer."
|
||||
)
|
||||
assert torch.equal(reconverted_buffer.rewards, replay_buffer.rewards), (
|
||||
"Rewards from converted buffer should be equal to the original replay buffer."
|
||||
)
|
||||
assert torch.equal(reconverted_buffer.dones, replay_buffer.dones), (
|
||||
"Dones from converted buffer should be equal to the original replay buffer."
|
||||
)
|
||||
# Check only the part of the buffer that's actually filled with data
|
||||
assert torch.equal(
|
||||
reconverted_buffer.actions[: len(replay_buffer)],
|
||||
replay_buffer.actions[: len(replay_buffer)],
|
||||
), "Actions from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(replay_buffer.truncateds.shape[0]).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds, expected_truncateds), (
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
assert torch.equal(
|
||||
replay_buffer.states["observation.state"], reconverted_buffer.states["observation.state"]
|
||||
replay_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
reconverted_buffer.states["observation.state"][: len(replay_buffer)],
|
||||
), "State should be the same after converting to dataset and return back"
|
||||
|
||||
for i in range(4):
|
||||
|
||||
Reference in New Issue
Block a user