From 0239f6b17e69ba7fd4406a691350fd4dbb8d2057 Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Tue, 31 Mar 2026 10:03:12 -0700 Subject: [PATCH] Fix Kimi-k2 checkpoint conversion --- .../convert_deepseek_family_unscanned_ckpt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py index 63e73145a8..2ec98f6004 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py @@ -70,10 +70,11 @@ def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info) if key.endswith("_scale_inv"): raise ValueError("fp8 checkpoint is not supported.") if ds_ckpt.is_key_allowed(key, ds_ckpt.MTP_KEYS_TO_SKIP): - mapped_key = ds_ckpt.hf_to_maxtext_mapping(layer, num_experts, first_num_dense_layers, base_num_decoder_layers)[ - key - ] - chkpt_vars[mapped_key] = f.get_tensor(key) + mapped_key = ds_ckpt.hf_to_maxtext_mapping( + layer, num_experts, first_num_dense_layers, base_num_decoder_layers + ).get(key) + if mapped_key: + chkpt_vars[mapped_key] = f.get_tensor(key) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))