From ed05e55074ed8ee58cad8096b837ed582849c929 Mon Sep 17 00:00:00 2001 From: Haohuan Wang Date: Fri, 7 Feb 2025 21:27:02 +0000 Subject: [PATCH] remove threefry setting in jax 0.5.0 --- scripts/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index cd86818..2792156 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -199,7 +199,6 @@ def main(config: _config.TrainConfig): f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}." ) - jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003 jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser())) rng = jax.random.key(config.seed)