From f543cb1d87ada91645b1484ca5dab6ac3749be8d Mon Sep 17 00:00:00 2001
From: uzhilinsky
Date: Wed, 5 Feb 2025 23:15:24 -0800
Subject: [PATCH] Use data parallel sharding by default (#267)

Our model expects data parallel sharding, so this is a reasonable default
to use out of the box.
---
 src/openpi/training/data_loader.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py
index a7f7326..2fc57c1 100644
--- a/src/openpi/training/data_loader.py
+++ b/src/openpi/training/data_loader.py
@@ -214,7 +214,12 @@ class TorchDataLoader:
             raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
 
         if sharding is None:
-            sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
+            # Use data parallel sharding by default.
+            sharding = jax.sharding.NamedSharding(
+                jax.sharding.Mesh(jax.devices(), ("B",)),
+                jax.sharding.PartitionSpec("B"),
+            )
+
         self._sharding = sharding
         self._num_batches = num_batches
 
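For readers unfamiliar with JAX sharding, the sketch below illustrates what the
new default does: it builds a one-dimensional device mesh named "B" over all
local devices and partitions arrays along their leading (batch) dimension. The
batch shape, the jax.device_put call, and the visualization call are
illustrative assumptions for this sketch, not part of the patch.

import jax
import numpy as np

# Data parallel sharding as constructed in the patch: a 1-D mesh over
# all local devices with a single axis named "B".
mesh = jax.sharding.Mesh(jax.devices(), ("B",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("B"))

# Hypothetical batch; the leading dimension must be divisible by the
# device count for the partitioning to apply.
batch = np.zeros((8, 32), dtype=np.float32)
sharded = jax.device_put(batch, sharding)

# Each device holds an (8 // len(jax.devices()), 32) slice of the batch.
jax.debug.visualize_array_sharding(sharded)

With the old SingleDeviceSharding default, the entire batch landed on
jax.devices()[0]; with the new default, each accelerator receives an equal
slice of the batch, which is presumably what the data parallel training step
referenced in the commit message expects.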