5 changes: 2 additions & 3 deletions src/transformers/models/afmoe/modeling_afmoe.py
@@ -439,7 +439,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
self.layer_idx = layer_idx

self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
- self.attention_type = config.layer_types[layer_idx]

# Dual normalization for attention
self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -603,10 +602,10 @@ def forward(

position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers:
+ for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
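The change above captures the pattern repeated throughout this diff: instead of caching `config.layer_types[layer_idx]` on each decoder layer as `self.attention_type`, the model's forward loop indexes `config.layer_types` directly when picking the attention mask. A minimal sketch of the idea, with simplified stand-ins for the real transformers classes (`run_layers` and `causal_mask_mapping` are illustrative names here, not the library API):

```python
# Illustrative sketch only: the layer type is looked up from the config by index,
# so decoder layers no longer need to carry an attention_type attribute.
def run_layers(config, layers, hidden_states, causal_mask_mapping):
    # causal_mask_mapping maps a layer-type string to the mask for that type,
    # e.g. {"full_attention": full_mask, "sliding_attention": sliding_mask}.
    for i, decoder_layer in enumerate(layers):
        layer_mask = causal_mask_mapping[config.layer_types[i]]
        hidden_states = decoder_layer(hidden_states, attention_mask=layer_mask)
    return hidden_states
```

The same edit is applied to the corresponding modular file below and to the other models in this PR that keep a per-layer type list in their config.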
5 changes: 2 additions & 3 deletions src/transformers/models/afmoe/modular_afmoe.py
@@ -261,7 +261,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
self.layer_idx = layer_idx

self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
- self.attention_type = config.layer_types[layer_idx]

# Dual normalization for attention
self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -425,10 +424,10 @@ def forward(

position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers:
+ for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
4 changes: 2 additions & 2 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1164,8 +1164,8 @@ def forward(
mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

- for decoder_layer in self.layers:
- layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+ for i, decoder_layer in enumerate(self.layers):
+ layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

hidden_states, attn_weights = decoder_layer(
hidden_states,
4 changes: 2 additions & 2 deletions src/transformers/models/bamba/modular_bamba.py
@@ -839,8 +839,8 @@ def forward(
mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

- for decoder_layer in self.layers:
- layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+ for i, decoder_layer in enumerate(self.layers):
+ layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

hidden_states, attn_weights = decoder_layer(
hidden_states,
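Bamba's hybrid stack gets the analogous treatment: the choice between the mamba mask and the causal attention mask is now driven by `config.layers_block_type[i]` rather than a `layer_type` attribute stored on the layer. A small sketch of that selection logic (the function name and placeholder masks are hypothetical, not Bamba's actual API):

```python
# Sketch under assumed names: pick each layer's mask from the block-type list.
def select_layer_masks(layers_block_type, mamba_mask, causal_mask):
    return [
        mamba_mask if block_type == "mamba" else causal_mask
        for block_type in layers_block_type
    ]

# Example with a toy 4-layer hybrid schedule:
masks = select_layer_masks(
    ["mamba", "attention", "mamba", "attention"],
    mamba_mask="mamba_mask",      # placeholders standing in for real mask tensors
    causal_mask="causal_mask",
)
# -> ['mamba_mask', 'causal_mask', 'mamba_mask', 'causal_mask']
```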
5 changes: 2 additions & 3 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -290,7 +290,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
self.mlp = Cohere2MLP(config)
self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
- self.attention_type = config.layer_types[layer_idx]

def forward(
self,
@@ -413,10 +412,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers:
+ for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
past_key_values=past_key_values,
use_cache=use_cache,
5 changes: 2 additions & 3 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -248,7 +248,6 @@ def forward(
class Cohere2DecoderLayer(CohereDecoderLayer):
def __init__(self, config: Cohere2Config, layer_idx: int):
super().__init__(config, layer_idx)
- self.attention_type = config.layer_types[layer_idx]

def forward(
self,
@@ -329,10 +328,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers:
+ for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
past_key_values=past_key_values,
use_cache=use_cache,
5 changes: 2 additions & 3 deletions src/transformers/models/cwm/modeling_cwm.py
@@ -285,7 +285,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
self.mlp = CwmMLP(config)
self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.attention_type = config.layer_types[layer_idx]

def forward(
self,
@@ -407,10 +406,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_values=past_key_values,
position_embeddings=position_embeddings,
5 changes: 2 additions & 3 deletions src/transformers/models/cwm/modular_cwm.py
@@ -133,7 +133,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
class CwmDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: CwmConfig, layer_idx: int):
super().__init__(config=config, layer_idx=layer_idx)
- self.attention_type = config.layer_types[layer_idx]
self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)


@@ -196,10 +195,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_values=past_key_values,
position_embeddings=position_embeddings,
11 changes: 7 additions & 4 deletions src/transformers/models/dots1/configuration_dots1.py
@@ -1,3 +1,9 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_dots1.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +20,7 @@

from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_rope_utils import RopeParameters
- from ...utils import auto_docstring, logging
-
-
- logger = logging.get_logger(__name__)
+ from ...utils import auto_docstring


@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
5 changes: 2 additions & 3 deletions src/transformers/models/dots1/modeling_dots1.py
@@ -417,7 +417,6 @@ def __init__(self, config: Dots1Config, layer_idx: int):

self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.attention_type = config.layer_types[layer_idx]

def forward(
self,
@@ -549,10 +548,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
139 changes: 134 additions & 5 deletions src/transformers/models/dots1/modular_dots1.py
@@ -13,9 +13,11 @@
# limitations under the License.
import torch

+ from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_outputs import CausalLMOutputWithPast
+ from ...modeling_rope_utils import RopeParameters
from ...processing_utils import Unpack
- from ...utils import logging
+ from ...utils import auto_docstring, logging
from ..deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3DecoderLayer,
DeepseekV3MLP,
@@ -31,12 +33,140 @@
Qwen3RotaryEmbedding,
TransformersKwargs,
)
- from .configuration_dots1 import Dots1Config


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
class Dots1Config(PreTrainedConfig):
r"""
n_group (`int`, *optional*, defaults to 1):
Number of groups for routed experts.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers at the beginning of the model before the first MoE layer.

Examples:

```python
>>> from transformers import Dots1Model, Dots1Config
>>> # Initializing a Dots1 style configuration
>>> configuration = Dots1Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""

model_type = "dots1"
keys_to_ignore_at_inference = ["past_key_values"]

base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
"layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
"layers.*.mlp.experts.gate_up_proj": "packed_colwise",
"layers.*.mlp.experts.down_proj": "rowwise",
"layers.*.mlp.experts": "moe_tp_experts",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}

base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
attribute_map = {
"num_local_experts": "n_routed_experts",
}

def __init__(
self,
vocab_size: int | None = 152064,
hidden_size: int | None = 4608,
intermediate_size: int | None = 10944,
moe_intermediate_size: int | None = 1408,
num_hidden_layers: int | None = 62,
num_attention_heads: int | None = 32,
num_key_value_heads: int | None = 32,
n_shared_experts: int | None = None,
n_routed_experts: int | None = None,
n_group: int | None = 1,
topk_group: int | None = 1,
num_experts_per_tok: int | None = None,
first_k_dense_replace: int | None = 0,
norm_topk_prob: bool | None = False,
hidden_act: str | None = "silu",
max_position_embeddings: int | None = 2048,
initializer_range: float | None = 0.02,
rms_norm_eps: int | None = 1e-6,
use_cache: bool | None = True,
tie_word_embeddings: bool | None = False,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
attention_bias: bool | None = False,
attention_dropout: float | None = 0.0,
routed_scaling_factor: float | None = 1.0,
sliding_window: int | None = 4096,
max_window_layers: int | None = 62,
layer_types: list[str] | None = None,
pad_token_id: int | None = None,
bos_token_id: int | None = None,
eos_token_id: int | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.num_experts_per_tok = num_experts_per_tok
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.n_group = n_group
self.topk_group = topk_group
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.routed_scaling_factor = routed_scaling_factor
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention"
if self.sliding_window is not None and i >= self.max_window_layers
else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)

self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.rope_parameters = rope_parameters

super().__init__(**kwargs)


class Dots1RMSNorm(Qwen3RMSNorm):
pass

@@ -85,9 +215,7 @@ def route_tokens_to_experts(self, router_logits):


class Dots1DecoderLayer(DeepseekV3DecoderLayer):
- def __init__(self, config: Dots1Config, layer_idx: int):
- super().__init__(config, layer_idx)
- self.attention_type = config.layer_types[layer_idx]
+ pass


class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
@@ -129,6 +257,7 @@ def forward(


__all__ = [
"Dots1Config",
"Dots1PreTrainedModel",
"Dots1Model",
"Dots1ForCausalLM",
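The `Dots1Config` now defined in the modular file derives a default `layer_types` list from `sliding_window`, `max_window_layers`, and `num_hidden_layers` when none is passed. A hedged example of that derivation, using small illustrative values rather than the model's real defaults (with the real defaults of 62 layers and `max_window_layers=62`, every entry comes out as `"full_attention"`):

```python
# Reproduces the default layer_types logic from the config above, with toy values.
sliding_window = 4096
max_window_layers = 4    # illustrative; the real default is 62
num_hidden_layers = 6    # illustrative; the real default is 62

layer_types = [
    "sliding_attention" if sliding_window is not None and i >= max_window_layers else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention']
```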
5 changes: 2 additions & 3 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -304,7 +304,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
- self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -456,10 +455,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,