From ef0e4ada17441100e6a59a0b11352ecdfc17d295 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sat, 14 Mar 2026 12:41:00 +0100
Subject: [PATCH 1/2] Fix Qwen2-VL based models' pipeline parallel support

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 9 ++++-----
 .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 5 ++---
 src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 5 ++---
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index de8e0050be89..5bad4c4e945c 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -1516,7 +1516,6 @@ def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2_5OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2_5OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -1667,10 +1666,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=text_position_ids,
                 past_key_values=past_key_values,
@@ -2187,10 +2186,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=text_position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 6e3acc5466e7..ed426d5a80d2 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -774,7 +774,6 @@ def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2_5_VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2_5_VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -924,10 +923,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=text_position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 2ecb7773581b..5a51d5ac8921 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -609,7 +609,6 @@ def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -892,10 +891,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=text_position_ids,
                 past_key_values=past_key_values,

From 8d99b6ef557de1a4ebc3f33cc53907be5165387e Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sat, 14 Mar 2026 14:07:39 +0100
Subject: [PATCH 2/2] Fix other models too

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 .../models/afmoe/modeling_afmoe.py | 5 +-
 .../models/afmoe/modular_afmoe.py | 5 +-
 .../models/bamba/modeling_bamba.py | 4 +-
 .../models/bamba/modular_bamba.py | 4 +-
 .../models/cohere2/modeling_cohere2.py | 5 +-
 .../models/cohere2/modular_cohere2.py | 5 +-
 src/transformers/models/cwm/modeling_cwm.py | 5 +-
 src/transformers/models/cwm/modular_cwm.py | 5 +-
 .../models/dots1/configuration_dots1.py | 11 +-
 .../models/dots1/modeling_dots1.py | 5 +-
 .../models/dots1/modular_dots1.py | 139 +++++++++++++++++-
 .../models/gemma2/modeling_gemma2.py | 5 +-
 .../models/gemma2/modular_gemma2.py | 5 +-
 .../models/gemma3/modeling_gemma3.py | 7 +-
 .../models/gemma3/modular_gemma3.py | 7 +-
 .../models/gemma3n/modeling_gemma3n.py | 9 +-
 .../models/gemma3n/modular_gemma3n.py | 8 +-
 .../models/gpt_oss/modeling_gpt_oss.py | 5 +-
 .../models/gpt_oss/modular_gpt_oss.py | 5 +-
 .../modeling_granitemoehybrid.py | 4 +-
 .../modular_granitemoehybrid.py | 4 +-
 src/transformers/models/lfm2/modeling_lfm2.py | 4 +-
 src/transformers/models/lfm2/modular_lfm2.py | 4 +-
 .../models/lfm2_moe/modeling_lfm2_moe.py | 4 +-
 .../models/lfm2_moe/modular_lfm2_moe.py | 4 +-
 .../models/llama4/modeling_llama4.py | 5 +-
 .../models/minimax/modeling_minimax.py | 4 +-
 .../models/minimax/modular_minimax.py | 4 +-
 .../models/ministral/modeling_ministral.py | 5 +-
 .../models/ministral/modular_ministral.py | 4 +-
 .../modeling_modernbert_decoder.py | 7 +-
 .../modular_modernbert_decoder.py | 7 +-
 .../models/olmo3/modeling_olmo3.py | 4 +-
 .../models/olmo3/modular_olmo3.py | 4 +-
 .../olmo_hybrid/modeling_olmo_hybrid.py | 6 +-
 .../models/olmo_hybrid/modular_olmo_hybrid.py | 6 +-
 .../models/qwen2/modeling_qwen2.py | 5 +-
 .../models/qwen2/modular_qwen2.py | 8 +-
 .../models/qwen3/modeling_qwen3.py | 5 +-
 .../models/qwen3_5/modeling_qwen3_5.py | 2 +-
 .../models/qwen3_5/modular_qwen3_5.py | 2 +-
 .../qwen3_5_moe/modeling_qwen3_5_moe.py | 2 +-
 .../models/qwen3_next/modeling_qwen3_next.py | 4 +-
 .../models/qwen3_next/modular_qwen3_next.py | 4 +-
 .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 9 +-
 .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 4 +-
 .../models/smollm3/modeling_smollm3.py | 5 +-
 .../models/smollm3/modular_smollm3.py | 4 +-
 .../models/vaultgemma/modeling_vaultgemma.py | 5 +-
 49 files changed, 244 insertions(+), 139 deletions(-)

diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py
index 7fb24d347c00..35642c56a411 100644
--- a/src/transformers/models/afmoe/modeling_afmoe.py
+++ b/src/transformers/models/afmoe/modeling_afmoe.py
@@ -439,7 +439,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.layer_idx = layer_idx
 
         self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
 
         # Dual normalization for attention
         self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -603,10 +602,10 @@ def forward(
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/afmoe/modular_afmoe.py b/src/transformers/models/afmoe/modular_afmoe.py
index 5fd37eb57e07..6d1b1338b4f0 100644
--- a/src/transformers/models/afmoe/modular_afmoe.py
+++ b/src/transformers/models/afmoe/modular_afmoe.py
@@ -261,7 +261,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.layer_idx = layer_idx
 
         self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
 
         # Dual normalization for attention
         self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -425,10 +424,10 @@ def forward(
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index 4d2bcf028b07..8888fa5ddbdf 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -1164,8 +1164,8 @@ def forward(
         mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
-        for decoder_layer in self.layers:
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
 
             hidden_states, attn_weights = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
index c9370dd0f986..4961025f1743 100644
--- a/src/transformers/models/bamba/modular_bamba.py
+++ b/src/transformers/models/bamba/modular_bamba.py
@@ -839,8 +839,8 @@ def forward(
         mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
-        for decoder_layer in self.layers:
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
 
             hidden_states, attn_weights = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index c2597b872b72..f43b2a0ef412 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -290,7 +290,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Cohere2MLP(config)
         self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -413,10 +412,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index ceff0d1db06f..a2e6bfa75f18 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -248,7 +248,6 @@ def forward(
 class Cohere2DecoderLayer(CohereDecoderLayer):
     def __init__(self, config: Cohere2Config, layer_idx: int):
         super().__init__(config, layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -329,10 +328,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py
index 772076754eb2..3e0eb0504be0 100644
--- a/src/transformers/models/cwm/modeling_cwm.py
+++ b/src/transformers/models/cwm/modeling_cwm.py
@@ -285,7 +285,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
         self.mlp = CwmMLP(config)
         self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -407,10 +406,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
diff --git a/src/transformers/models/cwm/modular_cwm.py b/src/transformers/models/cwm/modular_cwm.py
index 0f002052f002..a372914365ce 100644
--- a/src/transformers/models/cwm/modular_cwm.py
+++ b/src/transformers/models/cwm/modular_cwm.py
@@ -133,7 +133,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
 class CwmDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: CwmConfig, layer_idx: int):
         super().__init__(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
 
         self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)
 
@@ -196,10 +195,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
diff --git a/src/transformers/models/dots1/configuration_dots1.py b/src/transformers/models/dots1/configuration_dots1.py
index d9ea018ecdca..356e5f3726e1 100644
--- a/src/transformers/models/dots1/configuration_dots1.py
+++ b/src/transformers/models/dots1/configuration_dots1.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dots1.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +20,7 @@
 
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...modeling_rope_utils import RopeParameters
-from ...utils import auto_docstring, logging
-
-
-logger = logging.get_logger(__name__)
+from ...utils import auto_docstring
 
 
 @auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index 998af195ecd5..399194648663 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -417,7 +417,6 @@ def __init__(self, config: Dots1Config, layer_idx: int):
 
         self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -549,10 +548,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/dots1/modular_dots1.py b/src/transformers/models/dots1/modular_dots1.py
index 189e28d7e37c..f9c113800e1c 100644
--- a/src/transformers/models/dots1/modular_dots1.py
+++ b/src/transformers/models/dots1/modular_dots1.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 import torch
 
+from ...configuration_utils import PreTrainedConfig, layer_type_validation
 from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_rope_utils import RopeParameters
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import auto_docstring, logging
 from ..deepseek_v3.modeling_deepseek_v3 import (
     DeepseekV3DecoderLayer,
     DeepseekV3MLP,
@@ -31,12 +33,140 @@
     Qwen3RotaryEmbedding,
     TransformersKwargs,
 )
-from .configuration_dots1 import Dots1Config
 
 
 logger = logging.get_logger(__name__)
 
 
+@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
+class Dots1Config(PreTrainedConfig):
+    r"""
+    n_group (`int`, *optional*, defaults to 1):
+        Number of groups for routed experts.
+    first_k_dense_replace (`int`, *optional*, defaults to 0):
+        Number of dense layers at the beginning of the model before the first MoE layer.
+
+    Examples:
+
+    ```python
+    >>> from transformers import Dots1Model, Dots1Config
+    >>> # Initializing a Dots1 style configuration
+    >>> configuration = Dots1Config()
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "dots1"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
+        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
+        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+        "layers.*.mlp.experts.down_proj": "rowwise",
+        "layers.*.mlp.experts": "moe_tp_experts",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+    attribute_map = {
+        "num_local_experts": "n_routed_experts",
+    }
+
+    def __init__(
+        self,
+        vocab_size: int | None = 152064,
+        hidden_size: int | None = 4608,
+        intermediate_size: int | None = 10944,
+        moe_intermediate_size: int | None = 1408,
+        num_hidden_layers: int | None = 62,
+        num_attention_heads: int | None = 32,
+        num_key_value_heads: int | None = 32,
+        n_shared_experts: int | None = None,
+        n_routed_experts: int | None = None,
+        n_group: int | None = 1,
+        topk_group: int | None = 1,
+        num_experts_per_tok: int | None = None,
+        first_k_dense_replace: int | None = 0,
+        norm_topk_prob: bool | None = False,
+        hidden_act: str | None = "silu",
+        max_position_embeddings: int | None = 2048,
+        initializer_range: float | None = 0.02,
+        rms_norm_eps: int | None = 1e-6,
+        use_cache: bool | None = True,
+        tie_word_embeddings: bool | None = False,
+        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+        attention_bias: bool | None = False,
+        attention_dropout: float | None = 0.0,
+        routed_scaling_factor: float | None = 1.0,
+        sliding_window: int | None = 4096,
+        max_window_layers: int | None = 62,
+        layer_types: list[str] | None = None,
+        pad_token_id: int | None = None,
+        bos_token_id: int | None = None,
+        eos_token_id: int | None = None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_shared_experts = n_shared_experts
+        self.n_routed_experts = n_routed_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.first_k_dense_replace = first_k_dense_replace
+        self.norm_topk_prob = norm_topk_prob
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.n_group = n_group
+        self.topk_group = topk_group
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.routed_scaling_factor = routed_scaling_factor
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
+        self.tie_word_embeddings = tie_word_embeddings
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+        self.rope_parameters = rope_parameters
+
+        super().__init__(**kwargs)
+
+
 class Dots1RMSNorm(Qwen3RMSNorm):
     pass
 
@@ -85,9 +215,7 @@ def route_tokens_to_experts(self, router_logits):
 
 
 class Dots1DecoderLayer(DeepseekV3DecoderLayer):
-    def __init__(self, config: Dots1Config, layer_idx: int):
-        super().__init__(config, layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
+    pass
 
 
 class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
@@ -129,6 +257,7 @@ def forward(
 
 
 __all__ = [
+    "Dots1Config",
     "Dots1PreTrainedModel",
     "Dots1Model",
     "Dots1ForCausalLM",
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 63ee2874a4a4..247395eeab37 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -304,7 +304,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.config = config
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -456,10 +455,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index a6c1c4e758c5..70e847655cbc 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -291,7 +291,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.config = config
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -390,10 +389,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index a83e3332c1fb..3f5fad0d8f5a 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -394,7 +394,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.config = config
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma3MLP(config)
         self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
@@ -571,11 +570,11 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                position_embeddings=position_embeddings[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 **kwargs,
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 3ffcd97373cd..a04a19e58ffa 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -449,7 +449,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.config = config
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma3MLP(config)
         self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
@@ -600,11 +599,11 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                position_embeddings=position_embeddings[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 **kwargs,
diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index 27c9163e3e4b..2a9bd2e77b07 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -1282,7 +1282,6 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
         self.config = config
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma3nTextAttention(config, layer_idx)
         self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
         self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
@@ -1700,13 +1699,13 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            causal_mask = causal_mask_mapping[decoder_layer.attention_type]
-            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            causal_mask = causal_mask_mapping[self.config.layer_types[i]]
+            per_layer_input = per_layer_inputs[:, :, i, :]
 
             hidden_states = decoder_layer(
                 hidden_states,
-                position_embeddings[decoder_layer.attention_type],
+                position_embeddings[self.config.layer_types[i]],
                 per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py
index f55fead4560c..85c43b3a400d 100644
--- a/src/transformers/models/gemma3n/modular_gemma3n.py
+++ b/src/transformers/models/gemma3n/modular_gemma3n.py
@@ -1987,13 +1987,13 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            causal_mask = causal_mask_mapping[decoder_layer.attention_type]
-            per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            causal_mask = causal_mask_mapping[self.config.layer_types[i]]
+            per_layer_input = per_layer_inputs[:, :, i, :]
 
             hidden_states = decoder_layer(
                 hidden_states,
-                position_embeddings[decoder_layer.attention_type],
+                position_embeddings[self.config.layer_types[i]],
                 per_layer_input,
                 attention_mask=causal_mask,
                 position_ids=position_ids,
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 9157c32f1626..b730e1cf363d 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -353,7 +353,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int):
         self.mlp = GptOssMLP(config)
         self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -487,10 +486,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index b2e6374dd1f2..9187fd1ebd82 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -278,7 +278,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int):
         self.mlp = GptOssMLP(config)
         self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -384,10 +383,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 78d3d232bd5e..676ec8f93773 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1312,9 +1312,9 @@ def forward(
         if self.rotary_emb is not None:
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
index ead20868aa92..4c72531bddb5 100644
--- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py
@@ -245,9 +245,9 @@ def forward(
         if self.rotary_emb is not None:
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index ef79261364ba..b995c5ed1bc5 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -674,8 +674,8 @@ def forward(
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
         # decoder layers
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=layer_mask,
diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py
index 65119a287abd..ce178fabb3f8 100644
--- a/src/transformers/models/lfm2/modular_lfm2.py
+++ b/src/transformers/models/lfm2/modular_lfm2.py
@@ -480,8 +480,8 @@ def forward(
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
 
         # decoder layers
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=layer_mask,
diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
index 03bdbbdc95f8..e9575cd255f8 100644
--- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
+++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py
@@ -764,8 +764,8 @@ def forward(
         position_embeddings = self.pos_emb(hidden_states, position_ids=position_ids)
 
         # decoder layers
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=layer_mask,
diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py
index e5847937ce35..7d97d8b70dd5 100644
--- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py
+++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py
@@ -195,8 +195,8 @@ def forward(
         position_embeddings = self.pos_emb(hidden_states, position_ids=position_ids)
 
         # decoder layers
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
             hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=layer_mask,
diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index b04931be001b..edf8f3276747 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -415,7 +415,6 @@ def __init__(self, config, layer_idx):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Llama4TextAttention(config, layer_idx)
         self.is_moe_layer = layer_idx in config.moe_layers
         if self.is_moe_layer:  # the 128E model interleaves dense / sparse
@@ -567,10 +566,10 @@ def forward(
         # create position embeddings to be shared across the decoder layers
         freq_cis = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index d6eaab3426c5..d6b6871bfe31 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -679,8 +679,8 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
-            if decoder_layer.layer_type == "full_attention":
+        for i, decoder_layer in enumerate(self.layers):
+            if self.config.layer_types[i] == "full_attention":
                 input_attention_mask = causal_mask
             else:
                 # lightning attention uses original attention_mask, and uses it only for the first step
diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py
index 56871bcfa18f..feb4fb4384de 100644
--- a/src/transformers/models/minimax/modular_minimax.py
+++ b/src/transformers/models/minimax/modular_minimax.py
@@ -502,8 +502,8 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers:
-            if decoder_layer.layer_type == "full_attention":
+        for i, decoder_layer in enumerate(self.layers):
+            if self.config.layer_types[i] == "full_attention":
                 input_attention_mask = causal_mask
             else:
                 # lightning attention uses original attention_mask, and uses it only for the first step
diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py
index d9856ca49694..d3eb723c9840 100644
--- a/src/transformers/models/ministral/modeling_ministral.py
+++ b/src/transformers/models/ministral/modeling_ministral.py
@@ -212,7 +212,6 @@ def __init__(self, config: MinistralConfig, layer_idx: int):
         self.mlp = MinistralMLP(config)
         self.input_layernorm = MinistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = MinistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -394,10 +393,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/ministral/modular_ministral.py b/src/transformers/models/ministral/modular_ministral.py
index 0c27cfb5316b..33c3dc9d2e4e 100644
--- a/src/transformers/models/ministral/modular_ministral.py
+++ b/src/transformers/models/ministral/modular_ministral.py
@@ -185,10 +185,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index 1cb956151c70..f61ff6089796 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -312,7 +312,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.attn_norm = (
             nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
             if layer_idx != 0
@@ -508,11 +507,11 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                position_embeddings=position_embeddings[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 past_key_values=past_key_values,
                 position_ids=position_ids,
                 **kwargs,
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index f0c439847c9a..b5f603ae503f 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -325,7 +325,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: int | None = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.attn_norm = (
             nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
             if layer_idx != 0
@@ -508,11 +507,11 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
 
-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                position_embeddings=position_embeddings[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 past_key_values=past_key_values,
                 position_ids=position_ids,
                 **kwargs,
diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py
index 379b63aeacca..5baa8e5f24ed 100644
--- a/src/transformers/models/olmo3/modeling_olmo3.py
+++ b/src/transformers/models/olmo3/modeling_olmo3.py
@@ -413,10 +413,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
diff --git a/src/transformers/models/olmo3/modular_olmo3.py b/src/transformers/models/olmo3/modular_olmo3.py
index 4f1dd96b4d28..bcc304b93785 100644
--- a/src/transformers/models/olmo3/modular_olmo3.py
+++ b/src/transformers/models/olmo3/modular_olmo3.py
@@ -260,10 +260,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
index f23bb8b42245..b6c9d83d5bb2 100644
--- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
+++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
@@ -994,9 +994,9 @@ def forward(
         # RoPE or NoPE
         position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.rotary_emb is not None else None
 
-        for decoder_layer in self.layers:
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
-            layer_position_embeddings = position_embeddings if decoder_layer.layer_type == "full_attention" else None
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask
+            layer_position_embeddings = position_embeddings if self.config.layer_types[i] == "full_attention" else None
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
index 9e3dc808a57d..40bbc5dbe074 100644
--- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
+++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
@@ -743,9 +743,9 @@ def forward(
         # RoPE or NoPE
         position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.rotary_emb is not None else None
 
-        for decoder_layer in self.layers:
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
-            layer_position_embeddings = position_embeddings if decoder_layer.layer_type == "full_attention" else None
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask
+            layer_position_embeddings = position_embeddings if self.config.layer_types[i] == "full_attention" else None
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 8bd0a380066c..9263e1d42937 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -276,7 +276,6 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -396,10 +395,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py
index 5a5e6c0d3455..bcc1496d3acf 100644
--- a/src/transformers/models/qwen2/modular_qwen2.py
+++ b/src/transformers/models/qwen2/modular_qwen2.py
@@ -99,9 +99,7 @@ def forward(
 
 
 class Qwen2DecoderLayer(LlamaDecoderLayer):
-    def __init__(self, config: Qwen2Config, layer_idx: int):
-        super().__init__(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
+    pass
 
 
 class Qwen2PreTrainedModel(LlamaPreTrainedModel):
@@ -161,10 +159,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 8074eb290f7d..91715a33cf9d 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -301,7 +301,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int):
         self.mlp = Qwen3MLP(config)
         self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -421,10 +420,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py
index da76b98f0402..e15625c0ec91 100644
--- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py
+++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py
@@ -1355,7 +1355,7 @@ def forward(
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
+            layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py
index f679eafc322f..a229b2372c74 100644
--- a/src/transformers/models/qwen3_5/modular_qwen3_5.py
+++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py
@@ -580,7 +580,7 @@ def forward(
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
+            layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
index 941fb1a05978..08d003ec5d31 100644
--- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
+++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
@@ -1480,7 +1480,7 @@ def forward(
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
+            layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index 03037e88351d..7b45f0ea4838 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -1040,8 +1040,8 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py
index d7ef3b035014..a22b85bf9278 100644
--- a/src/transformers/models/qwen3_next/modular_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py
@@ -796,8 +796,8 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask
 
             hidden_states = decoder_layer(
                 hidden_states,
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index ed8621cbbee0..4ea7bd693774 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -2451,7 +2451,6 @@ def __init__(self, config, layer_idx):
         self.mlp = Qwen3OmniMoeMLP(config)
         self.input_layernorm = Qwen3OmniMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen3OmniMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -2624,10 +2623,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
@@ -3628,10 +3627,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
index 8e548f703845..4368509625b4 100644
--- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
@@ -1652,10 +1652,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py
index 1b3c1c07ee10..8d911e414b0f 100644
--- a/src/transformers/models/smollm3/modeling_smollm3.py
+++ b/src/transformers/models/smollm3/modeling_smollm3.py
@@ -305,7 +305,6 @@ def __init__(self, config: SmolLM3Config, layer_idx: int):
         self.mlp = SmolLM3MLP(config)
         self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]
 
     def forward(
         self,
@@ -425,10 +424,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py
index 5dc614b347ea..c19ee7a2a3f6 100644
--- a/src/transformers/models/smollm3/modular_smollm3.py
+++ b/src/transformers/models/smollm3/modular_smollm3.py
@@ -223,9 +223,7 @@ def forward(
 
 
 class SmolLM3DecoderLayer(LlamaDecoderLayer):
-    def __init__(self, config: SmolLM3Config, layer_idx: int):
-        super().__init__(config, layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
+    pass
 
 
 class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
index 1d8b05637ffe..78b1431f1a35 100644
--- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py
+++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py
@@ -236,7 +236,6 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.config = config
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = VaultGemmaAttention(config=config, layer_idx=layer_idx)
         self.mlp = VaultGemmaMLP(config)
         self.input_layernorm = VaultGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -447,10 +446,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
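
Both patches make the same mechanical change: each decoder layer used to cache its own layer type in `__init__` (`self.attention_type = config.layer_types[layer_idx]`, or `layer_type` / `is_attention_layer` in the hybrid models), and `forward` read it back off the layer object; now `forward` enumerates the layers and looks the type up positionally in the config (`config.layer_types[i]`, or `config.layers_block_type[i]` for the Mamba hybrids). The commit messages do not spell out the failure mode, but the likely reason this matters for pipeline parallelism is that a PP runtime keeps `self.layers` at full length while replacing out-of-rank entries with bare placeholder modules (vLLM's `PPMissingLayer`, for example): a placeholder carries no `attention_type` attribute, whereas the config is present on every rank, and the `enumerate` index still equals the global layer index because the list length is unchanged. A minimal sketch of that reading — `PlaceholderLayer`, `TinyConfig`, and `TinyModel` are illustrative stand-ins, not transformers or vLLM source:

```python
import torch
from torch import nn


class PlaceholderLayer(nn.Identity):
    """Hypothetical stand-in for a decoder layer hosted on another PP rank.

    It carries none of the real layer's attributes, so `layer.attention_type`
    would raise AttributeError on it, while `config.layer_types[i]` resolves.
    """


class TinyConfig:
    layer_types = ["full_attention", "sliding_attention", "full_attention"]


class TinyModel(nn.Module):
    def __init__(self, config, local_layer_indices):
        super().__init__()
        self.config = config
        # The ModuleList keeps its full length, so `i` in forward() is the
        # global layer index; only this rank's layers are real modules.
        self.layers = nn.ModuleList(
            nn.Linear(8, 8) if i in local_layer_indices else PlaceholderLayer()
            for i in range(len(config.layer_types))
        )

    def forward(self, hidden_states):
        for i, layer in enumerate(self.layers):
            # Config-indexed lookup works for every slot, placeholder or not;
            # the old `layer.attention_type` lookup would fail on placeholders.
            layer_type = self.config.layer_types[i]
            assert layer_type in ("full_attention", "sliding_attention")
            hidden_states = layer(hidden_states)
        return hidden_states


model = TinyModel(TinyConfig(), local_layer_indices={0, 1})
out = model(torch.zeros(1, 8))  # runs even though layer 2 is a placeholder
```

Under this reading, the old attribute lookup fails as soon as the loop reaches a placeholder, while the config-indexed lookup is insensitive to what occupies each slot, which is the property both commits rely on.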