diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index ea65ea5fed..ee510123e3 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2190,7 +2190,7 @@ def get_num_target_devices(): # Default quantization sharding count to number of local devices if not set. if self.quantization_local_shard_count == -1: try: - self.quantization_local_shard_count = jax.local_device_count() + self.quantization_local_shard_count = 1 + max(d.slice_index for d in jax.devices()) except RuntimeError: self.quantization_local_shard_count = 1