[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -176,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
||||
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
|
||||
torch.testing.assert_close(
|
||||
data, torch.tensor([0, 2, 3]), msg="Data does not match expected values"
|
||||
)
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
@@ -212,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), (
|
||||
"Data does not match expected values"
|
||||
)
|
||||
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
|
||||
"Padding does not match expected values"
|
||||
)
|
||||
@@ -275,7 +279,8 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
|
||||
)
|
||||
|
||||
|
||||
@@ -297,7 +302,8 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
|
||||
)
|
||||
|
||||
|
||||
@@ -318,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
|
||||
online_sampling_ratio=0.5,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
||||
torch.testing.assert_close(
|
||||
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user