remove threefry setting in jax 0.5.0
This commit is contained in:
@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
|
||||
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
|
||||
)
|
||||
|
||||
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
|
||||
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
|
||||
|
||||
rng = jax.random.key(config.seed)
|
||||
|
||||
Reference in New Issue
Block a user