Use data parallel sharding by default (#267)

Our model expects data-parallel sharding, so this is a reasonable default to use out of the box.
This commit is contained in:
uzhilinsky
2025-02-05 23:15:24 -08:00
committed by GitHub
parent 6104624aca
commit f543cb1d87

View File

@@ -214,7 +214,12 @@ class TorchDataLoader:
raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
if sharding is None:
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
# Use data parallel sharding by default.
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ("B",)),
jax.sharding.PartitionSpec("B"),
)
self._sharding = sharding
self._num_batches = num_batches