remove threefry setting in jax 0.5.0 (#279)
continuation conversation from: https://app.graphite.dev/github/pr/Physical-Intelligence/monopi/6672/upgrade-jax-to-0-5-0?utm_source=gt-slack-notif&panel=timeline#comment-PRRC_kwDOLnRTkc50DYoL
This commit is contained in:
@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
|
||||
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
|
||||
)
|
||||
|
||||
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
|
||||
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
|
||||
|
||||
rng = jax.random.key(config.seed)
|
||||
|
||||
Reference in New Issue
Block a user