diff --git a/scripts/train.py b/scripts/train.py index cd86818..2792156 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -199,7 +199,6 @@ def main(config: _config.TrainConfig): f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}." ) - jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003 jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser())) rng = jax.random.key(config.seed)