diff --git a/src/openpi/training/data_loader.py b/src/openpi/training/data_loader.py index a7f7326..2fc57c1 100644 --- a/src/openpi/training/data_loader.py +++ b/src/openpi/training/data_loader.py @@ -214,7 +214,12 @@ class TorchDataLoader: raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).") if sharding is None: - sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + # Use data parallel sharding by default. + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), ("B",)), + jax.sharding.PartitionSpec("B"), + ) + self._sharding = sharding self._num_batches = num_batches