diff --git a/checkpoint/orbax/checkpoint/_src/metadata/sharding.py b/checkpoint/orbax/checkpoint/_src/metadata/sharding.py index a6ca18ebb..55e31c9b7 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/sharding.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/sharding.py @@ -325,7 +325,13 @@ def __repr__(self): return f'SingleDeviceShardingMetadata(device_str={self.device_str})' def __eq__(self, other): - return self.device_str == other.device_str + if not isinstance(other, SingleDeviceShardingMetadata): + return False + # JAX 0.10 changed CPU devices so they report as cpu:0 not TFRT_CPU_0 + return ( + self.device_str.replace('TFRT_CPU_', 'cpu:') + == other.device_str.replace('TFRT_CPU_', 'cpu:') + ) def from_jax_sharding(jax_sharding) -> Optional[ShardingMetadata]: diff --git a/checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py b/checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py index 8afde9e70..bafb21a69 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py @@ -93,12 +93,7 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata( jax_sharding = jax.sharding.SingleDeviceSharding( jax.local_devices(backend="cpu")[0] ) - # JAX used to report its cpu devices as TFRT_CPU_0 expected_single_device_sharding_metadata = ( - sharding_metadata.SingleDeviceShardingMetadata(device_str="TFRT_CPU_0") - ) - # ... but now uses cpu:0 - expected_single_device_sharding_metadata2 = ( sharding_metadata.SingleDeviceShardingMetadata(device_str="cpu:0") ) converted_single_device_sharding_metadata = ( @@ -109,12 +104,9 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata( converted_single_device_sharding_metadata, sharding_metadata.SingleDeviceShardingMetadata, ) - self.assertIn( + self.assertEqual( converted_single_device_sharding_metadata, - [ - expected_single_device_sharding_metadata, - expected_single_device_sharding_metadata2, - ], + expected_single_device_sharding_metadata, ) # Convert from `SingleDeviceShardingMetadata` to