Use data parallel sharding by default (#267)
Our model expects data-parallel sharding, so this is a reasonable default to use out of the box.
This commit is contained in:
@@ -214,7 +214,12 @@ class TorchDataLoader:
             raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

         if sharding is None:
-            sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
+            # Use data parallel sharding by default.
+            sharding = jax.sharding.NamedSharding(
+                jax.sharding.Mesh(jax.devices(), ("B",)),
+                jax.sharding.PartitionSpec("B"),
+            )

         self._sharding = sharding
         self._num_batches = num_batches
Reference in New Issue
Block a user