remove threefry setting in jax 0.5.0

This commit is contained in:
Haohuan Wang
2025-02-07 21:27:02 +00:00
parent 007e2b91ed
commit ed05e55074

View File

@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
)
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
rng = jax.random.key(config.seed)