Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/transformers/integrations/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,24 +282,38 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
# On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
return_lse = query.device.type != "cpu"

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
# NOTE(review): `base_version` is a *string*, so this is a lexicographic
# comparison — "2.10" >= "2.9" evaluates to False even though torch 2.10 is
# newer. Compare parsed versions instead, e.g.
#   version.parse(torch_version).release >= (2, 9)
# or version.parse(torch_version) >= version.parse("2.9")
use_return_aux = version.parse(torch_version).base_version >= "2.9"

if not return_lse and s_aux is not None:
    raise ValueError(
        "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
    )

# Build the kwargs for flex attention
flex_attn_kwargs = {
    "score_mod": score_mod,
    "block_mask": block_mask,
    "enable_gqa": enable_gqa,
    "scale": scaling,
    "kernel_options": kernel_options,
    "training": module.training,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
    # NOTE(review): confirm that a plain bool is accepted here — the PyTorch 2.9
    # flex_attention signature documents `return_aux` as an `AuxRequest`
    # (e.g. AuxRequest(lse=True)) and returns an `AuxOutput`, so passing
    # `return_lse`'s bool may not be equivalent to the old `return_lse=True`
    # and may change what the caller unpacks from the output.
    flex_attn_kwargs["return_aux"] = return_lse
else:
    flex_attn_kwargs["return_lse"] = return_lse

flex_attention_output = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=return_lse,
training=module.training,
**flex_attn_kwargs,
)
# lse is returned in float32
if return_lse:
Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/doge/modeling_doge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...utils.import_utils import get_torch_version
from packaging import version
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
Expand Down Expand Up @@ -233,17 +235,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
return score

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
# NOTE(review): `base_version` is a *string*, so this is a lexicographic
# comparison — "2.10" >= "2.9" evaluates to False even though torch 2.10 is
# newer. Compare parsed versions instead, e.g.
#   version.parse(torch_version).release >= (2, 9)
use_return_aux = version.parse(torch_version).base_version >= "2.9"

# Build kwargs for flex attention
flex_attn_kwargs = {
    "score_mod": score_mod,
    "block_mask": block_mask,
    "enable_gqa": True,
    "scale": scaling,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
    # NOTE(review): confirm a bool is valid here — PyTorch 2.9 documents
    # `return_aux` as an `AuxRequest` (AuxRequest(lse=True)) returning an
    # `AuxOutput`, so `True` may not reproduce the old `return_lse=True`
    # tuple shape that the caller unpacks below.
    flex_attn_kwargs["return_aux"] = True
else:
    flex_attn_kwargs["return_lse"] = True

attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
**flex_attn_kwargs,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
Expand Down
30 changes: 23 additions & 7 deletions src/transformers/models/doge/modular_doge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from ...cache_utils import Cache
from ...configuration_utils import PreTrainedConfig
from ...integrations.flex_attention import compile_friendly_flex_attention
from ...utils.import_utils import get_torch_version
from packaging import version
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import RopeParameters
Expand Down Expand Up @@ -202,17 +204,31 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + causal_mask[batch_idx][head_idx][q_idx][kv_idx]
return score

# PyTorch >= 2.9 renamed return_lse to return_aux
torch_version = get_torch_version()
# NOTE(review): `base_version` is a *string*, so this is a lexicographic
# comparison — "2.10" >= "2.9" evaluates to False even though torch 2.10 is
# newer. Compare parsed versions instead, e.g.
#   version.parse(torch_version).release >= (2, 9)
use_return_aux = version.parse(torch_version).base_version >= "2.9"

# Build kwargs for flex attention
flex_attn_kwargs = {
    "score_mod": score_mod,
    "block_mask": block_mask,
    "enable_gqa": True,
    "scale": scaling,
}

# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
# In PyTorch >= 2.9, return_lse was renamed to return_aux
if use_return_aux:
    # NOTE(review): confirm a bool is valid here — PyTorch 2.9 documents
    # `return_aux` as an `AuxRequest` (AuxRequest(lse=True)) returning an
    # `AuxOutput`, so `True` may not reproduce the old `return_lse=True`
    # tuple shape that the caller unpacks below.
    flex_attn_kwargs["return_aux"] = True
else:
    flex_attn_kwargs["return_lse"] = True

attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
scale=scaling,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
**flex_attn_kwargs,
)
# lse is returned in float32
attention_weights = attention_weights.to(value.dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/configuration_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
hidden_act: str | None = "silu",
max_position_embeddings: int | None = 4096 * 32,
initializer_range: float | None = 0.02,
rms_norm_eps: int | None = 1e-6,
rms_norm_eps: float | None = 1e-6,
use_cache: bool | None = True,
pad_token_id: int | None = None,
bos_token_id: int | None = 1,
Expand Down
24 changes: 24 additions & 0 deletions src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,30 @@ def __init__(
self.vision_end_token_id = vision_end_token_id
self.tie_word_embeddings = tie_word_embeddings
super().__init__(**kwargs)
# Propagate classification-related attributes to text_config for sequence classification tasks
# The text config is used for the actual model forward pass in tasks like sequence classification
self._propagate_classification_config_to_text_config()

def _propagate_classification_config_to_text_config(self):
"""Propagate classification-related attributes to text_config."""
if hasattr(self, "text_config") and self.text_config is not None:
if hasattr(self, "num_labels") and self.num_labels is not None:
self.text_config.num_labels = self.num_labels
if hasattr(self, "id2label") and self.id2label is not None:
self.text_config.id2label = self.id2label.copy()
if hasattr(self, "label2id") and self.label2id is not None:
self.text_config.label2id = self.label2id.copy()

@property
def num_labels(self) -> int:
    """Number of labels, delegated to the base config's property."""
    return super().num_labels

@num_labels.setter
def num_labels(self, num_labels: int):
    """Set ``num_labels`` via the parent property, then sync ``text_config``.

    BUGFIX: the original wrote ``super(num_labels, type(self))`` — the local
    *int* parameter shadows the property name, so ``super`` received an int
    as its first operand and every assignment raised
    ``TypeError: super() argument 1 must be a type``. ``__class__`` (the
    implicit PEP 3135 cell for the defining class) is the correct first
    operand; binding ``super`` against a class returns the parent *property
    object* itself, so ``fset`` can be invoked explicitly.
    """
    # Call the parent setter first so id2label/label2id are (re)derived.
    super(__class__, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
    # Then propagate to text_config for sequence-classification tasks.
    self._propagate_classification_config_to_text_config()


__all__ = ["Qwen2VLConfig", "Qwen2VLTextConfig"]
7 changes: 6 additions & 1 deletion src/transformers/models/qwen3_5/configuration_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ def __init__(
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
# Pop classification-related parameters to propagate to text_config
text_config_kwargs = {}
for key in ["num_labels", "id2label", "label2id"]:
if key in kwargs:
text_config_kwargs[key] = kwargs.pop(key)
self.text_config = self.sub_configs["text_config"](**text_config_kwargs)

self.image_token_id = image_token_id
self.video_token_id = video_token_id
Expand Down
24 changes: 24 additions & 0 deletions src/transformers/models/qwen3_5/modular_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,30 @@ def __init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Propagate classification-related attributes to text_config for sequence classification tasks
# The text config is used for the actual model forward pass in tasks like sequence classification
self._propagate_classification_config_to_text_config()

def _propagate_classification_config_to_text_config(self):
"""Propagate classification-related attributes to text_config."""
if hasattr(self, "text_config") and self.text_config is not None:
if hasattr(self, "num_labels") and self.num_labels is not None:
self.text_config.num_labels = self.num_labels
if hasattr(self, "id2label") and self.id2label is not None:
self.text_config.id2label = self.id2label.copy()
if hasattr(self, "label2id") and self.label2id is not None:
self.text_config.label2id = self.label2id.copy()

@property
def num_labels(self) -> int:
    """Number of labels, delegated to the base config's property."""
    return super().num_labels

@num_labels.setter
def num_labels(self, num_labels: int):
    """Set ``num_labels`` via the parent property, then sync ``text_config``.

    BUGFIX: the original wrote ``super(num_labels, type(self))`` — the local
    *int* parameter shadows the property name, so ``super`` received an int
    as its first operand and every assignment raised
    ``TypeError: super() argument 1 must be a type``. ``__class__`` (the
    implicit PEP 3135 cell for the defining class) is the correct first
    operand; binding ``super`` against a class returns the parent *property
    object* itself, so ``fset`` can be invoked explicitly.
    """
    # Call the parent setter first so id2label/label2id are (re)derived.
    super(__class__, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
    # Then propagate to text_config for sequence-classification tasks.
    self._propagate_classification_config_to_text_config()


class Qwen3_5DynamicCache(Qwen3NextDynamicCache):
Expand Down
24 changes: 24 additions & 0 deletions src/transformers/models/qwen3_vl/configuration_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,30 @@ def __init__(
self.vision_end_token_id = vision_end_token_id
self.tie_word_embeddings = tie_word_embeddings
super().__init__(**kwargs)
# Propagate classification-related attributes to text_config for sequence classification tasks
# The text config is used for the actual model forward pass in tasks like sequence classification
self._propagate_classification_config_to_text_config()

def _propagate_classification_config_to_text_config(self):
"""Propagate classification-related attributes to text_config."""
if hasattr(self, "text_config") and self.text_config is not None:
if hasattr(self, "num_labels") and self.num_labels is not None:
self.text_config.num_labels = self.num_labels
if hasattr(self, "id2label") and self.id2label is not None:
self.text_config.id2label = self.id2label.copy()
if hasattr(self, "label2id") and self.label2id is not None:
self.text_config.label2id = self.label2id.copy()

@property
def num_labels(self) -> int:
    """Number of labels, delegated to the base config's property."""
    return super().num_labels

@num_labels.setter
def num_labels(self, num_labels: int):
    """Set ``num_labels`` via the parent property, then sync ``text_config``.

    BUGFIX: the original wrote ``super(num_labels, type(self))`` — the local
    *int* parameter shadows the property name, so ``super`` received an int
    as its first operand and every assignment raised
    ``TypeError: super() argument 1 must be a type``. ``__class__`` (the
    implicit PEP 3135 cell for the defining class) is the correct first
    operand; binding ``super`` against a class returns the parent *property
    object* itself, so ``fset`` can be invoked explicitly.
    """
    # Call the parent setter first so id2label/label2id are (re)derived.
    super(__class__, type(self)).num_labels.fset(self, num_labels)  # type: ignore[attr-defined]
    # Then propagate to text_config for sequence-classification tasks.
    self._propagate_classification_config_to_text_config()


__all__ = ["Qwen3VLConfig", "Qwen3VLTextConfig"]
Loading