Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ defmodule Bumblebee do
"GemmaModel" => {Bumblebee.Text.Gemma, :base},
"GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
"GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
"Gemma3Model" => {Bumblebee.Text.Gemma3, :base},
"Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling},
"Gemma3ForSequenceClassification" => {Bumblebee.Text.Gemma3, :for_sequence_classification},
"Gemma3TextModel" => {Bumblebee.Text.Gemma3, :base},
"Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling},
"Gemma3TextForSequenceClassification" =>
{Bumblebee.Text.Gemma3, :for_sequence_classification},
"GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
Expand Down Expand Up @@ -252,6 +259,8 @@ defmodule Bumblebee do
"camembert" => :camembert,
"clip" => :clip,
"gemma" => :gemma,
"gemma3" => :gemma,
"gemma3_text" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
Expand Down
19 changes: 19 additions & 0 deletions lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do
- a keyword list (applied to all blocks)
- a function that takes the block index and returns the configuration

* `:attention_window_size` - sliding window attention configuration. Can be:
- `nil` for global attention (default)
- a `{left, right}` tuple (applied to all blocks)
- a function that takes the block index and returns `nil` or `{left, right}`.
This enables per-layer attention patterns such as Gemma 3's alternating
local/global attention (5 local layers followed by 1 global layer)

* `:name` - the prefix for layer names

For all other options (including required options) see `block/2`.
Expand Down Expand Up @@ -66,6 +73,7 @@ defmodule Bumblebee.Layers.Transformer do
:name,
:num_blocks,
:rotary_embedding,
:attention_window_size,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
Expand All @@ -87,6 +95,7 @@ defmodule Bumblebee.Layers.Transformer do
cross_attention_head_mask = opts[:cross_attention_head_mask]
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]
attention_window_size = opts[:attention_window_size]

block_opts = Keyword.take(opts, block_opts_keys)

Expand Down Expand Up @@ -123,6 +132,15 @@ defmodule Bumblebee.Layers.Transformer do
config when is_list(config) -> config
end

# Support per-layer attention window size for models like Gemma 3
# that alternate between local (sliding window) and global attention
block_attention_window_size =
case attention_window_size do
nil -> nil
fun when is_function(fun, 1) -> fun.(idx)
size -> size
end

{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
block(
state.hidden_state,
Expand All @@ -136,6 +154,7 @@ defmodule Bumblebee.Layers.Transformer do
block_cache: block_cache,
offset: offset,
rotary_embedding: block_rotary_embedding,
attention_window_size: block_attention_window_size,
name: join(name, idx)
] ++ block_opts
)
Expand Down
Loading
Loading