remove try-catch

This commit is contained in:
Alexander Soare
2024-04-03 09:53:15 +01:00
parent c50a62dd6d
commit a6ec4fbf58

View File

@@ -58,12 +58,9 @@ def test_compute_stats():
for k, pattern in buffer.stats_patterns.items(): for k, pattern in buffer.stats_patterns.items():
expected_mean = einops.reduce(all_data[k], pattern, "mean") expected_mean = einops.reduce(all_data[k], pattern, "mean")
assert torch.allclose(computed_stats[k]["mean"], expected_mean) assert torch.allclose(computed_stats[k]["mean"], expected_mean)
try: assert torch.allclose(
assert torch.allclose( computed_stats[k]["std"],
computed_stats[k]["std"], torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean"))
torch.sqrt(einops.reduce((all_data[k] - expected_mean) ** 2, pattern, "mean")) )
)
except:
breakpoint()
assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min")) assert torch.allclose(computed_stats[k]["min"], einops.reduce(all_data[k], pattern, "min"))
assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max")) assert torch.allclose(computed_stats[k]["max"], einops.reduce(all_data[k], pattern, "max"))