diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py index 7fb24d347c00..35642c56a411 100644 --- a/src/transformers/models/afmoe/modeling_afmoe.py +++ b/src/transformers/models/afmoe/modeling_afmoe.py @@ -439,7 +439,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int): self.layer_idx = layer_idx self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx) - self.attention_type = config.layer_types[layer_idx] # Dual normalization for attention self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -603,10 +602,10 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/afmoe/modular_afmoe.py b/src/transformers/models/afmoe/modular_afmoe.py index 5fd37eb57e07..6d1b1338b4f0 100644 --- a/src/transformers/models/afmoe/modular_afmoe.py +++ b/src/transformers/models/afmoe/modular_afmoe.py @@ -261,7 +261,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int): self.layer_idx = layer_idx self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx) - self.attention_type = config.layer_types[layer_idx] # Dual normalization for attention self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -425,10 +424,10 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 4d2bcf028b07..8888fa5ddbdf 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1164,8 +1164,8 @@ def forward( mamba_mask = self._update_mamba_mask(attention_mask, past_key_values) position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) - for decoder_layer in self.layers: - layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + for i, decoder_layer in enumerate(self.layers): + layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask hidden_states, attn_weights = decoder_layer( hidden_states, diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index c9370dd0f986..4961025f1743 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -839,8 +839,8 @@ def forward( mamba_mask = self._update_mamba_mask(attention_mask, past_key_values) position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) - for decoder_layer in self.layers: - layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + for i, decoder_layer in enumerate(self.layers): + layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask hidden_states, attn_weights = decoder_layer( hidden_states, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index c2597b872b72..f43b2a0ef412 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -290,7 +290,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int): self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx) self.mlp = Cohere2MLP(config) self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -413,10 +412,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 3dedea65db2c..22dec99166d5 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -220,7 +220,6 @@ def forward( class Cohere2DecoderLayer(CohereDecoderLayer): def __init__(self, config: Cohere2Config, layer_idx: int): super().__init__(config, layer_idx) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -301,10 +300,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 772076754eb2..3e0eb0504be0 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -285,7 +285,6 @@ def __init__(self, config: CwmConfig, layer_idx: int): self.mlp = CwmMLP(config) self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -407,10 +406,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, position_embeddings=position_embeddings, diff --git a/src/transformers/models/cwm/modular_cwm.py b/src/transformers/models/cwm/modular_cwm.py index a7d872fbfba4..321f992ffc9b 100644 --- a/src/transformers/models/cwm/modular_cwm.py +++ b/src/transformers/models/cwm/modular_cwm.py @@ -105,7 +105,6 @@ def __init__(self, config: CwmConfig, layer_idx: int): class CwmDecoderLayer(LlamaDecoderLayer): def __init__(self, config: CwmConfig, layer_idx: int): super().__init__(config=config, layer_idx=layer_idx) - self.attention_type = config.layer_types[layer_idx] self.self_attn = CwmAttention(config=config, layer_idx=layer_idx) @@ -168,10 +167,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, position_embeddings=position_embeddings, diff --git a/src/transformers/models/dots1/configuration_dots1.py b/src/transformers/models/dots1/configuration_dots1.py index ef9ec8407528..b2c3222c618a 100644 --- a/src/transformers/models/dots1/configuration_dots1.py +++ b/src/transformers/models/dots1/configuration_dots1.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dots1.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 998af195ecd5..399194648663 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -417,7 +417,6 @@ def __init__(self, config: Dots1Config, layer_idx: int): self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -549,10 +548,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/dots1/modular_dots1.py b/src/transformers/models/dots1/modular_dots1.py index 189e28d7e37c..f9c113800e1c 100644 --- a/src/transformers/models/dots1/modular_dots1.py +++ b/src/transformers/models/dots1/modular_dots1.py @@ -13,9 +13,11 @@ # limitations under the License. import torch +from ...configuration_utils import PreTrainedConfig, layer_type_validation from ...modeling_outputs import CausalLMOutputWithPast +from ...modeling_rope_utils import RopeParameters from ...processing_utils import Unpack -from ...utils import logging +from ...utils import auto_docstring, logging from ..deepseek_v3.modeling_deepseek_v3 import ( DeepseekV3DecoderLayer, DeepseekV3MLP, @@ -31,12 +33,140 @@ Qwen3RotaryEmbedding, TransformersKwargs, ) -from .configuration_dots1 import Dots1Config logger = logging.get_logger(__name__) +@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base") +class Dots1Config(PreTrainedConfig): + r""" + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers at the beginning of the model before the first MoE layer. + + Examples: + + ```python + >>> from transformers import Dots1Model, Dots1Config + >>> # Initializing a Dots1 style configuration + >>> configuration = Dots1Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dots1" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: int | None = 152064, + hidden_size: int | None = 4608, + intermediate_size: int | None = 10944, + moe_intermediate_size: int | None = 1408, + num_hidden_layers: int | None = 62, + num_attention_heads: int | None = 32, + num_key_value_heads: int | None = 32, + n_shared_experts: int | None = None, + n_routed_experts: int | None = None, + n_group: int | None = 1, + topk_group: int | None = 1, + num_experts_per_tok: int | None = None, + first_k_dense_replace: int | None = 0, + norm_topk_prob: bool | None = False, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 2048, + initializer_range: float | None = 0.02, + rms_norm_eps: int | None = 1e-6, + use_cache: bool | None = True, + tie_word_embeddings: bool | None = False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + attention_bias: bool | None = False, + attention_dropout: float | None = 0.0, + routed_scaling_factor: float | None = 1.0, + sliding_window: int | None = 4096, + max_window_layers: int | None = 62, + layer_types: list[str] | None = None, + pad_token_id: int | None = None, + bos_token_id: int | None = None, + eos_token_id: int | None = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.n_group = n_group + self.topk_group = topk_group + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.routed_scaling_factor = routed_scaling_factor + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + + self.tie_word_embeddings = tie_word_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.rope_parameters = rope_parameters + + super().__init__(**kwargs) + + class Dots1RMSNorm(Qwen3RMSNorm): pass @@ -85,9 +215,7 @@ def route_tokens_to_experts(self, router_logits): class Dots1DecoderLayer(DeepseekV3DecoderLayer): - def __init__(self, config: Dots1Config, layer_idx: int): - super().__init__(config, layer_idx) - self.attention_type = config.layer_types[layer_idx] + pass class Dots1PreTrainedModel(DeepseekV3PreTrainedModel): @@ -129,6 +257,7 @@ def forward( __all__ = [ + "Dots1Config", "Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM", diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 5b68f113aa51..20673571b2d2 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -304,7 +304,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.config = config - self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -456,10 +455,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 73ae5ba66208..7ea2328260f3 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -271,7 +271,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.config = config - self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma2MLP(config) self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -370,10 +369,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index a99e873eaa0a..23607505156f 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -394,7 +394,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma3MLP(config) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) @@ -571,11 +570,11 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_embeddings=position_embeddings[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_embeddings=position_embeddings[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, **kwargs, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index cb0456a29da4..b57d16c7609b 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -392,7 +392,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx) self.mlp = Gemma3MLP(config) self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) @@ -543,11 +542,11 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_embeddings=position_embeddings[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_embeddings=position_embeddings[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, **kwargs, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 67068f3449b1..d37df841ca17 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1282,7 +1282,6 @@ def __init__(self, config: Gemma3nTextConfig, layer_idx: int): self.config = config self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.self_attn = Gemma3nTextAttention(config, layer_idx) self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx) self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) @@ -1700,13 +1699,13 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - causal_mask = causal_mask_mapping[decoder_layer.attention_type] - per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :] + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + causal_mask = causal_mask_mapping[self.config.layer_types[i]] + per_layer_input = per_layer_inputs[:, :, i, :] hidden_states = decoder_layer( hidden_states, - position_embeddings[decoder_layer.attention_type], + position_embeddings[self.config.layer_types[i]], per_layer_input, attention_mask=causal_mask, position_ids=position_ids, diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 208e874cced3..beed89720ab0 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1889,13 +1889,13 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - causal_mask = causal_mask_mapping[decoder_layer.attention_type] - per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :] + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + causal_mask = causal_mask_mapping[self.config.layer_types[i]] + per_layer_input = per_layer_inputs[:, :, i, :] hidden_states = decoder_layer( hidden_states, - position_embeddings[decoder_layer.attention_type], + position_embeddings[self.config.layer_types[i]], per_layer_input, attention_mask=causal_mask, position_ids=position_ids, diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index c94ae36ea099..18f31ea90379 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -353,7 +353,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int): self.mlp = GptOssMLP(config) self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -487,10 +486,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index ee175b55fb77..687e8864efeb 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -278,7 +278,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int): self.mlp = GptOssMLP(config) self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -384,10 +383,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 78d3d232bd5e..676ec8f93773 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1312,9 +1312,9 @@ def forward( if self.rotary_emb is not None: position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) - layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index ead20868aa92..4c72531bddb5 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -245,9 +245,9 @@ def forward( if self.rotary_emb is not None: position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) - layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index ef79261364ba..b995c5ed1bc5 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -674,8 +674,8 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention hidden_states = decoder_layer( hidden_states, attention_mask=layer_mask, diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 65119a287abd..ce178fabb3f8 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -480,8 +480,8 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention hidden_states = decoder_layer( hidden_states, attention_mask=layer_mask, diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 03bdbbdc95f8..e9575cd255f8 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -764,8 +764,8 @@ def forward( position_embeddings = self.pos_emb(hidden_states, position_ids=position_ids) # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention hidden_states = decoder_layer( hidden_states, attention_mask=layer_mask, diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py index e5847937ce35..7d97d8b70dd5 100644 --- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -195,8 +195,8 @@ def forward( position_embeddings = self.pos_emb(hidden_states, position_ids=position_ids) # decoder layers - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = causal_mask if decoder_layer.is_attention_layer else linear_attention + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention hidden_states = decoder_layer( hidden_states, attention_mask=layer_mask, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 314d19b75636..08d50bd63f72 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -415,7 +415,6 @@ def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.self_attn = Llama4TextAttention(config, layer_idx) self.is_moe_layer = layer_idx in config.moe_layers if self.is_moe_layer: # the 128E model interleaves dense / sparse @@ -569,10 +568,10 @@ def forward( # create position embeddings to be shared across the decoder layers freq_cis = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index d6eaab3426c5..d6b6871bfe31 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -679,8 +679,8 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: - if decoder_layer.layer_type == "full_attention": + for i, decoder_layer in enumerate(self.layers): + if self.config.layer_types[i] == "full_attention": input_attention_mask = causal_mask else: # lightning attention uses original attention_mask, and uses it only for the first step diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 07d7540a7686..9e372bdd1ebc 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -462,8 +462,8 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: - if decoder_layer.layer_type == "full_attention": + for i, decoder_layer in enumerate(self.layers): + if self.config.layer_types[i] == "full_attention": input_attention_mask = causal_mask else: # lightning attention uses original attention_mask, and uses it only for the first step diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index 57f39907610f..af4f7fbeae59 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -227,7 +227,6 @@ def __init__(self, config: MinistralConfig, layer_idx: int): self.mlp = MinistralMLP(config) self.input_layernorm = MinistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MinistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -409,10 +408,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/ministral/modular_ministral.py b/src/transformers/models/ministral/modular_ministral.py index 2e433d62e9d6..6fe101ba3a54 100644 --- a/src/transformers/models/ministral/modular_ministral.py +++ b/src/transformers/models/ministral/modular_ministral.py @@ -156,10 +156,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 47294e277aca..82cf25ac2dbc 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -312,7 +312,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: int | None = None super().__init__() self.config = config self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.attn_norm = ( nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) if layer_idx != 0 @@ -508,11 +507,11 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_embeddings=position_embeddings[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_embeddings=position_embeddings[self.config.layer_types[i]], past_key_values=past_key_values, position_ids=position_ids, **kwargs, diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index 8c3ace76b059..5c43e468f44e 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -290,7 +290,6 @@ def __init__(self, config: ModernBertDecoderConfig, layer_idx: int | None = None super().__init__() self.config = config self.layer_idx = layer_idx - self.attention_type = config.layer_types[layer_idx] self.attn_norm = ( nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) if layer_idx != 0 @@ -473,11 +472,11 @@ def forward( for layer_type in self.config.layer_types: position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_embeddings=position_embeddings[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], + position_embeddings=position_embeddings[self.config.layer_types[i]], past_key_values=past_key_values, position_ids=position_ids, **kwargs, diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index 379b63aeacca..5baa8e5f24ed 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -413,10 +413,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, position_embeddings=position_embeddings, diff --git a/src/transformers/models/olmo3/modular_olmo3.py b/src/transformers/models/olmo3/modular_olmo3.py index 4be52a8cab52..7ebc4b63ee58 100644 --- a/src/transformers/models/olmo3/modular_olmo3.py +++ b/src/transformers/models/olmo3/modular_olmo3.py @@ -216,10 +216,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, position_embeddings=position_embeddings, diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index 09fd0312b02c..e3b83bff770a 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -995,9 +995,9 @@ def forward( # RoPE or NoPE position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.rotary_emb is not None else None - for decoder_layer in self.layers: - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask - layer_position_embeddings = position_embeddings if decoder_layer.layer_type == "full_attention" else None + for i, decoder_layer in enumerate(self.layers): + layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask + layer_position_embeddings = position_embeddings if self.config.layer_types[i] == "full_attention" else None hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index f9c9fc9dd1f3..9da82c8b8b8c 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -707,9 +707,9 @@ def forward( # RoPE or NoPE position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.rotary_emb is not None else None - for decoder_layer in self.layers: - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask - layer_position_embeddings = position_embeddings if decoder_layer.layer_type == "full_attention" else None + for i, decoder_layer in enumerate(self.layers): + layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask + layer_position_embeddings = position_embeddings if self.config.layer_types[i] == "full_attention" else None hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8bd0a380066c..9263e1d42937 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -276,7 +276,6 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -396,10 +395,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 5a5e6c0d3455..bcc1496d3acf 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -99,9 +99,7 @@ def forward( class Qwen2DecoderLayer(LlamaDecoderLayer): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__(config=config, layer_idx=layer_idx) - self.attention_type = config.layer_types[layer_idx] + pass class Qwen2PreTrainedModel(LlamaPreTrainedModel): @@ -161,10 +159,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 1e7b43d9bc44..7cc8c7d0fab6 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1516,7 +1516,6 @@ def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2_5OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2_5OmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -1667,10 +1666,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=text_position_ids, past_key_values=past_key_values, @@ -2187,10 +2186,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=text_position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ace1559a0cc0..db822669e943 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -774,7 +774,6 @@ def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2_5_VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2_5_VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -924,10 +923,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=text_position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 87391926d99b..dab149852456 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -609,7 +609,6 @@ def __init__(self, config: Qwen2VLTextConfig, layer_idx: int): self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2VLRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -892,10 +891,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers: + for i, decoder_layer in enumerate(self.layers): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=text_position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 8074eb290f7d..91715a33cf9d 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -301,7 +301,6 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.mlp = Qwen3MLP(config) self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -421,10 +420,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 25af85a34a04..11ce02bebcb8 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1357,7 +1357,7 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index bdd7bb42f0a9..77c3978bd58a 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -523,7 +523,7 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 127a31ac8376..5822b94ebe87 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1482,7 +1482,7 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + layer_mask = linear_attn_mask if self.config.layer_types[layer_idx] == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 03037e88351d..7b45f0ea4838 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -1040,8 +1040,8 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index d7ef3b035014..a22b85bf9278 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -796,8 +796,8 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + layer_mask = linear_attn_mask if self.config.layer_types[i] == "linear_attention" else causal_mask hidden_states = decoder_layer( hidden_states, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index a575b3e88dae..46f0fa2f3fdf 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -2451,7 +2451,6 @@ def __init__(self, config, layer_idx): self.mlp = Qwen3OmniMoeMLP(config) self.input_layernorm = Qwen3OmniMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3OmniMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -2624,10 +2623,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, @@ -3628,10 +3627,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 6f5ec59f0bbd..c7e7855d88e0 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1409,10 +1409,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 1b3c1c07ee10..8d911e414b0f 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -305,7 +305,6 @@ def __init__(self, config: SmolLM3Config, layer_idx: int): self.mlp = SmolLM3MLP(config) self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -425,10 +424,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py index d7ce821d7a95..fd175008a601 100644 --- a/src/transformers/models/smollm3/modular_smollm3.py +++ b/src/transformers/models/smollm3/modular_smollm3.py @@ -191,9 +191,7 @@ def forward( class SmolLM3DecoderLayer(LlamaDecoderLayer): - def __init__(self, config: SmolLM3Config, layer_idx: int): - super().__init__(config, layer_idx) - self.attention_type = config.layer_types[layer_idx] + pass class SmolLM3PreTrainedModel(LlamaPreTrainedModel): diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 62d9439aa66d..f0a2e48d20b8 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -236,7 +236,6 @@ def __init__(self, config: VaultGemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.config = config - self.attention_type = config.layer_types[layer_idx] self.self_attn = VaultGemmaAttention(config=config, layer_idx=layer_idx) self.mlp = VaultGemmaMLP(config) self.input_layernorm = VaultGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -447,10 +446,10 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], + attention_mask=causal_mask_mapping[self.config.layer_types[i]], position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values,