Use data parallel sharding by default (#267)
Our model expects data-parallel sharding, so this is a reasonable default to use out of the box.
@@ -214,7 +214,12 @@ class TorchDataLoader:
             raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
 
         if sharding is None:
-            sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
+            # Use data parallel sharding by default.
+            sharding = jax.sharding.NamedSharding(
+                jax.sharding.Mesh(jax.devices(), ("B",)),
+                jax.sharding.PartitionSpec("B"),
+            )
 
         self._sharding = sharding
         self._num_batches = num_batches
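For context, a minimal sketch of what the new default does in practice (the batch shape and the `batch` array below are illustrative assumptions, not part of the commit): the 1-D mesh labels all local devices with axis "B", and PartitionSpec("B") splits the leading (batch) dimension of each array across that axis.

import jax
import numpy as np

# Build the same kind of default sharding the loader now constructs:
# a 1-D device mesh with axis "B", partitioning arrays along their
# leading (batch) dimension.
mesh = jax.sharding.Mesh(jax.devices(), ("B",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("B"))

# Illustrative batch (hypothetical shape); the leading dimension must be
# divisible by the number of devices.
batch = np.zeros((8 * jax.device_count(), 32), dtype=np.float32)
sharded = jax.device_put(batch, sharding)

# Each device now holds an (8, 32) shard of the batch, whereas the old
# SingleDeviceSharding placed the entire batch on jax.devices()[0].
print(sharded.sharding)

With this default, batches coming out of the loader are already laid out for data-parallel training, without the caller having to pass a sharding explicitly.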