Remove random sampling

This commit is contained in:
Alexander Soare
2024-04-02 16:52:38 +01:00
parent 95293d459d
commit a6edb85da4
2 changed files with 10 additions and 12 deletions

View File

@@ -53,7 +53,9 @@ def test_compute_stats():
batch_size=len(buffer),
sampler=SamplerWithoutReplacement(),
).sample().float()
computed_stats = buffer._compute_stats()
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics.
computed_stats = buffer._compute_stats(batch_size=int(len(all_data) * 0.75))
for k, pattern in buffer.stats_patterns.items():
expected_mean = einops.reduce(all_data[k], pattern, "mean")
assert torch.allclose(computed_stats[k]["mean"], expected_mean)