From 548cd6b5533305e70b0035f2f47a753f2a2cf313 Mon Sep 17 00:00:00 2001
From: LincolnBurrows2017
Date: Fri, 13 Mar 2026 20:44:47 +0000
Subject: [PATCH 1/4] fix: Support PyTorch 2.9+ return_aux parameter in
 flex_attention

---
 .../integrations/flex_attention.py            | 32 +++++++++++++------
 src/transformers/models/doge/modeling_doge.py | 30 +++++++++++++----
 src/transformers/models/doge/modular_doge.py  | 30 +++++++++++++----
 3 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 10737a984225..9727269e775c 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -282,24 +282,38 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
     return_lse = query.device.type != "cpu"
 
+    # PyTorch >= 2.9 renamed return_lse to return_aux.
+    # NOTE: compare release tuples, not strings — a lexicographic comparison on
+    # base_version would make "2.10" sort below "2.9".
+    torch_version = get_torch_version()
+    use_return_aux = version.parse(torch_version).release >= (2, 9)
+
     if not return_lse and s_aux is not None:
         raise ValueError(
             "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
         )
 
+    # Build the kwargs for flex attention
+    flex_attn_kwargs = {
+        "score_mod": score_mod,
+        "block_mask": block_mask,
+        "enable_gqa": enable_gqa,
+        "scale": scaling,
+        "kernel_options": kernel_options,
+        "training": module.training,
+    }
+
+    # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+    # For simplification, we thus always return it as no additional computations are introduced.
+    # In PyTorch >= 2.9, return_lse was renamed to return_aux
+    if use_return_aux:
+        flex_attn_kwargs["return_aux"] = return_lse
+    else:
+        flex_attn_kwargs["return_lse"] = return_lse
+
     flex_attention_output = compile_friendly_flex_attention(
         query,
         key,
         value,
-        score_mod=score_mod,
-        block_mask=block_mask,
-        enable_gqa=enable_gqa,
-        scale=scaling,
-        kernel_options=kernel_options,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
-        return_lse=return_lse,
-        training=module.training,
+        **flex_attn_kwargs,
     )
     # lse is returned in float32
     if return_lse:
diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py
index 4aad59b52a9a..6f2564e55dc5 100644
--- a/src/transformers/models/doge/modeling_doge.py
+++ b/src/transformers/models/doge/modeling_doge.py
@@ -34,6 +34,8 @@
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...integrations.flex_attention import compile_friendly_flex_attention
+from ...utils.import_utils import get_torch_version
+from packaging import version
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
@@ -233,17 +235,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
             return score
 
+        # PyTorch >= 2.9 renamed return_lse to return_aux.
+        # Compare release tuples, not strings ("2.10" < "2.9" lexicographically).
+        torch_version = get_torch_version()
+        use_return_aux = version.parse(torch_version).release >= (2, 9)
+
+        # Build kwargs for flex attention
+        flex_attn_kwargs = {
+            "score_mod": score_mod,
+            "block_mask": block_mask,
+            "enable_gqa": True,
+            "scale": scaling,
+        }
+
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
+        # In PyTorch >= 2.9, return_lse was renamed to return_aux
+        if use_return_aux:
+            flex_attn_kwargs["return_aux"] = True
+        else:
+            flex_attn_kwargs["return_lse"] = True
+
         attn_output, attention_weights = compile_friendly_flex_attention(
             query,
             key,
             value,
-            score_mod=score_mod,
-            block_mask=block_mask,
-            enable_gqa=True,
-            scale=scaling,
-            # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-            # For simplification, we thus always return it as no additional computations are introduced.
-            return_lse=True,
+            **flex_attn_kwargs,
         )
         # lse is returned in float32
         attention_weights = attention_weights.to(value.dtype)
diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py
index e1ca0b071fd1..b3826903f78d 100644
--- a/src/transformers/models/doge/modular_doge.py
+++ b/src/transformers/models/doge/modular_doge.py
@@ -28,6 +28,8 @@
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
 from ...integrations.flex_attention import compile_friendly_flex_attention
+from ...utils.import_utils import get_torch_version
+from packaging import version
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import RopeParameters
@@ -202,17 +204,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
             return score
 
+        # PyTorch >= 2.9 renamed return_lse to return_aux.
+        # Compare release tuples, not strings ("2.10" < "2.9" lexicographically).
+        torch_version = get_torch_version()
+        use_return_aux = version.parse(torch_version).release >= (2, 9)
+
+        # Build kwargs for flex attention
+        flex_attn_kwargs = {
+            "score_mod": score_mod,
+            "block_mask": block_mask,
+            "enable_gqa": True,
+            "scale": scaling,
+        }
+
+        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
+        # For simplification, we thus always return it as no additional computations are introduced.
+        # In PyTorch >= 2.9, return_lse was renamed to return_aux
+        if use_return_aux:
+            flex_attn_kwargs["return_aux"] = True
+        else:
+            flex_attn_kwargs["return_lse"] = True
+
         attn_output, attention_weights = compile_friendly_flex_attention(
             query,
             key,
             value,
-            score_mod=score_mod,
-            block_mask=block_mask,
-            enable_gqa=True,
-            scale=scaling,
-            # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-            # For simplification, we thus always return it as no additional computations are introduced.
-            return_lse=True,
+            **flex_attn_kwargs,
         )
         # lse is returned in float32
         attention_weights = attention_weights.to(value.dtype)

From bab54e74e1c01ad6381bca8cc5a5b51ab6598173 Mon Sep 17 00:00:00 2001
From: LincolnBurrows2017
Date: Sat, 14 Mar 2026 00:40:36 +0000
Subject: [PATCH 2/4] fix: Propagate num_labels to text_config in Qwen models

---
 .../models/qwen2_vl/configuration_qwen2_vl.py | 24 +++++++++++++++++++
 .../models/qwen3_5/modular_qwen3_5.py         | 24 +++++++++++++++++++
 .../models/qwen3_vl/configuration_qwen3_vl.py | 24 +++++++++++++++++++
 3 files changed, 72 insertions(+)

diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
index 243133267732..da22da366d3a 100644
--- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -230,6 +230,30 @@ def __init__(
         self.vision_end_token_id = vision_end_token_id
         self.tie_word_embeddings = tie_word_embeddings
         super().__init__(**kwargs)
+        # Propagate classification-related attributes to text_config for sequence classification tasks
+        # The text config is used for the actual model forward pass in tasks like sequence classification
+        self._propagate_classification_config_to_text_config()
+
+    def _propagate_classification_config_to_text_config(self):
+        """Propagate classification-related attributes to text_config."""
+        if hasattr(self, "text_config") and self.text_config is not None:
+            if hasattr(self, "num_labels") and self.num_labels is not None:
+                self.text_config.num_labels = self.num_labels
+            if hasattr(self, "id2label") and self.id2label is not None:
+                self.text_config.id2label = self.id2label.copy()
+            if hasattr(self, "label2id") and self.label2id is not None:
+                self.text_config.label2id = self.label2id.copy()
+
+    @property
+    def num_labels(self) -> int:
+        return super().num_labels
+
+    @num_labels.setter
+    def num_labels(self, num_labels: int):
+        # Call the parent setter first to set id2label; super() needs the class,
+        # not the value being set, as its first argument.
+        super(Qwen2VLConfig, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
+        # Then propagate to text_config
+        self._propagate_classification_config_to_text_config()
 
 
 __all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]
diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py
index f679eafc322f..14d776f45077 100644
--- a/src/transformers/models/qwen3_5/modular_qwen3_5.py
+++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py
@@ -210,6 +210,30 @@ def __init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+        # Propagate classification-related attributes to text_config for sequence classification tasks
+        # The text config is used for the actual model forward pass in tasks like sequence classification
+        self._propagate_classification_config_to_text_config()
+
+    def _propagate_classification_config_to_text_config(self):
+        """Propagate classification-related attributes to text_config."""
+        if hasattr(self, "text_config") and self.text_config is not None:
+            if hasattr(self, "num_labels") and self.num_labels is not None:
+                self.text_config.num_labels = self.num_labels
+            if hasattr(self, "id2label") and self.id2label is not None:
+                self.text_config.id2label = self.id2label.copy()
+            if hasattr(self, "label2id") and self.label2id is not None:
+                self.text_config.label2id = self.label2id.copy()
+
+    @property
+    def num_labels(self) -> int:
+        return super().num_labels
+
+    @num_labels.setter
+    def num_labels(self, num_labels: int):
+        # Call the parent setter first to set id2label; super() needs the class,
+        # not the value being set, as its first argument.
+        super(Qwen3_5Config, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
+        # Then propagate to text_config
+        self._propagate_classification_config_to_text_config()
 
 
 class Qwen3_5DynamicCache(Qwen3NextDynamicCache):
diff --git a/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
index 69a411fe0783..a6faa684bf6f 100644
--- a/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
@@ -189,6 +189,30 @@ def __init__(
         self.vision_end_token_id = vision_end_token_id
         self.tie_word_embeddings = tie_word_embeddings
         super().__init__(**kwargs)
+        # Propagate classification-related attributes to text_config for sequence classification tasks
+        # The text config is used for the actual model forward pass in tasks like sequence classification
+        self._propagate_classification_config_to_text_config()
+
+    def _propagate_classification_config_to_text_config(self):
+        """Propagate classification-related attributes to text_config."""
+        if hasattr(self, "text_config") and self.text_config is not None:
+            if hasattr(self, "num_labels") and self.num_labels is not None:
+                self.text_config.num_labels = self.num_labels
+            if hasattr(self, "id2label") and self.id2label is not None:
+                self.text_config.id2label = self.id2label.copy()
+            if hasattr(self, "label2id") and self.label2id is not None:
+                self.text_config.label2id = self.label2id.copy()
+
+    @property
+    def num_labels(self) -> int:
+        return super().num_labels
+
+    @num_labels.setter
+    def num_labels(self, num_labels: int):
+        # Call the parent setter first to set id2label; super() needs the class,
+        # not the value being set, as its first argument.
+        super(Qwen3VLConfig, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
+        # Then propagate to text_config
+        self._propagate_classification_config_to_text_config()
 
 
 __all__ = ["Qwen3VLConfig", "Qwen3VLTextConfig"]

From 7dfc9163de2f3fdd03b67b13b3e296e0e3aab3b8 Mon Sep 17 00:00:00 2001
From: LincolnBurrows2017
Date: Sat, 14 Mar 2026 05:42:42 +0000
Subject: [PATCH 3/4] fix: Propagate num_labels to text_config in Qwen3.5

---
 src/transformers/models/qwen3_5/configuration_qwen3_5.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/qwen3_5/configuration_qwen3_5.py b/src/transformers/models/qwen3_5/configuration_qwen3_5.py
index a2b739c6629d..582aaaeb3cea 100644
--- a/src/transformers/models/qwen3_5/configuration_qwen3_5.py
+++ b/src/transformers/models/qwen3_5/configuration_qwen3_5.py
@@ -227,7 +227,12 @@ def __init__(
         if isinstance(text_config, dict):
             self.text_config = self.sub_configs["text_config"](**text_config)
         elif text_config is None:
-            self.text_config = self.sub_configs["text_config"]()
+            # Pop classification-related parameters to propagate to text_config
+            text_config_kwargs = {}
+            for key in ["num_labels", "id2label", "label2id"]:
+                if key in kwargs:
+                    text_config_kwargs[key] = kwargs.pop(key)
+            self.text_config = self.sub_configs["text_config"](**text_config_kwargs)
 
         self.image_token_id = image_token_id
         self.video_token_id = video_token_id

From 550a9560e12e1025da5b30c9986c0abdd0f04eed Mon Sep 17 00:00:00 2001
From: LincolnBurrows2017
Date: Sat, 14 Mar 2026 14:39:26 +0000
Subject: [PATCH 4/4] fix: Correct rms_norm_eps type hint from int to float in
 MistralConfig

The rms_norm_eps parameter was incorrectly typed as int | None but defaults
to 1e-6 which is a float. This parameter is passed to MistralRMSNorm which
expects eps: float. This is a type hint bug that affects 59+ model
configurations in the codebase.
---
 src/transformers/models/mistral/configuration_mistral.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py
index 7bc24445ff90..a232496b30b5 100644
--- a/src/transformers/models/mistral/configuration_mistral.py
+++ b/src/transformers/models/mistral/configuration_mistral.py
@@ -69,7 +69,7 @@ def __init__(
         hidden_act: str | None = "silu",
         max_position_embeddings: int | None = 4096 * 32,
         initializer_range: float | None = 0.02,
-        rms_norm_eps: int | None = 1e-6,
+        rms_norm_eps: float | None = 1e-6,
         use_cache: bool | None = True,
         pad_token_id: int | None = None,
         bos_token_id: int | None = 1,