Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion checkpoint/orbax/checkpoint/_src/metadata/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,13 @@ def __repr__(self):
return f'SingleDeviceShardingMetadata(device_str={self.device_str})'

def __eq__(self, other):
return self.device_str == other.device_str
if not isinstance(other, SingleDeviceShardingMetadata):
return False
# JAX 0.10 changed CPU devices so they report as cpu:0 not TFRT_CPU_0
return (
self.device_str.replace('TFRT_CPU_', 'cpu:')
== other.device_str.replace('TFRT_CPU_', 'cpu:')
)


def from_jax_sharding(jax_sharding) -> Optional[ShardingMetadata]:
Expand Down
12 changes: 2 additions & 10 deletions checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,7 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
jax_sharding = jax.sharding.SingleDeviceSharding(
jax.local_devices(backend="cpu")[0]
)
# JAX used to report its cpu devices as TFRT_CPU_0
expected_single_device_sharding_metadata = (
sharding_metadata.SingleDeviceShardingMetadata(device_str="TFRT_CPU_0")
)
# ... but now uses cpu:0
expected_single_device_sharding_metadata2 = (
sharding_metadata.SingleDeviceShardingMetadata(device_str="cpu:0")
)
converted_single_device_sharding_metadata = (
Expand All @@ -109,12 +104,9 @@ def test_convert_between_jax_single_device_sharding_and_sharding_metadata(
converted_single_device_sharding_metadata,
sharding_metadata.SingleDeviceShardingMetadata,
)
self.assertIn(
self.assertEqual(
converted_single_device_sharding_metadata,
[
expected_single_device_sharding_metadata,
expected_single_device_sharding_metadata2,
],
expected_single_device_sharding_metadata,
)

# Convert from `SingleDeviceShardingMetadata` to
Expand Down
Loading