Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ def generate_step(
temperature: float = DEFAULT_TEMPERATURE,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = DEFAULT_REPETITION_CONTEXT_SIZE,
presence_penalty: Optional[float] = None,
presence_context_size: Optional[int] = None,
frequency_penalty: Optional[float] = None,
frequency_context_size: Optional[int] = None,
top_p: float = DEFAULT_TOP_P,
min_p: float = DEFAULT_MIN_P,
top_k: int = DEFAULT_TOP_K,
Expand All @@ -406,10 +410,21 @@ def generate_step(
mask: The attention mask (optional).
max_tokens (int): Maximum number of tokens to generate.
temperature (float): The temperature for sampling, if 0 the argmax is used.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_penalty (float, optional): The multiplicative penalty factor
for tokens that have appeared in the recent window. Each unique
repeated token is penalised once regardless of frequency.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty.
consider for repetition penalty (mlx_lm default: 20).
presence_penalty (float, optional): Additive penalty subtracted from a
logit if the token has occurred at least once in the recent window
(OpenAI semantics).
presence_context_size (int, optional): Token window for
``presence_penalty`` (mlx_lm default: 20).
frequency_penalty (float, optional): Additive penalty proportional to a
token's count in the recent window (OpenAI semantics). Effective at
breaking high-frequency repetition loops where a few tokens dominate.
frequency_context_size (int, optional): Token window for
``frequency_penalty`` (mlx_lm default: 20).
top_p (float, optional): Nucleus sampling threshold; higher values let the
model consider more, less-likely words.
min_p (float, optional): Minimum probability threshold relative to the
Expand Down Expand Up @@ -452,9 +467,23 @@ def generate_step(
top_k=top_k,
)

processors = make_logits_processors(
logit_bias, repetition_penalty, repetition_context_size
)
# Build the kwargs explicitly so unset fields fall back to mlx_lm defaults
# rather than being overridden by ``None``.
_processor_kwargs: Dict[str, Any] = {
"logit_bias": logit_bias,
"repetition_penalty": repetition_penalty,
"repetition_context_size": repetition_context_size,
}
if presence_penalty is not None:
_processor_kwargs["presence_penalty"] = presence_penalty
if presence_context_size is not None:
_processor_kwargs["presence_context_size"] = presence_context_size
if frequency_penalty is not None:
_processor_kwargs["frequency_penalty"] = frequency_penalty
if frequency_context_size is not None:
_processor_kwargs["frequency_context_size"] = frequency_context_size

processors = make_logits_processors(**_processor_kwargs)
if logits_processors is not None:
processors.extend(logits_processors)

Expand Down Expand Up @@ -809,8 +838,15 @@ def generate(
(default ``False``).
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
repetition_penalty (float, optional): The penalty factor for repeating tokens.
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
repetition_penalty (float, optional): Multiplicative penalty for repeated tokens.
repetition_context_size (int, optional): Token window for repetition_penalty.
presence_penalty (float, optional): Additive penalty for any token already seen
in the recent window (OpenAI semantics).
presence_context_size (int, optional): Token window for presence_penalty.
frequency_penalty (float, optional): Additive penalty proportional to a token's
occurrence count in the recent window (OpenAI semantics). Helpful for
breaking high-frequency loops.
frequency_context_size (int, optional): Token window for frequency_penalty.
"""

if verbose:
Expand Down
47 changes: 46 additions & 1 deletion mlx_vlm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,47 @@ class GenerationParams(FlexibleBaseModel):
description="Min-p sampling threshold.",
)
repetition_penalty: Optional[float] = Field(
None, description="Penalty applied to repeated tokens."
None,
description=(
"Multiplicative penalty applied to tokens that have appeared in the "
"recent generation window (Keskar et al. 2019). Each unique repeated "
"token is penalised once regardless of frequency; values below ~1.7 "
"may be insufficient to prevent loops on VLMs whose logits are "
"sharply peaked (e.g. OCR), where 2.0 is a more reliable starting "
"point. Combine with frequency_penalty for finer control."
),
)
repetition_context_size: Optional[int] = Field(
None,
description=(
"Number of previous tokens considered when applying the repetition "
"penalty. mlx_lm's default of 20 is small; raise this for long-form "
"generation (e.g. full-page OCR) where loop cycles exceed 20 tokens."
),
)
presence_penalty: Optional[float] = Field(
None,
description=(
"Additive penalty subtracted from a logit if the token has occurred "
"at least once in the recent window. Mirrors the OpenAI parameter."
),
)
presence_context_size: Optional[int] = Field(
None,
description="Token window for presence_penalty (mlx_lm default: 20).",
)
frequency_penalty: Optional[float] = Field(
None,
description=(
"Additive penalty proportional to a token's count in the recent "
"window. Effective at breaking high-frequency repetition loops "
"(e.g. OCR pages where one heading is repeated dozens of times). "
"Mirrors the OpenAI parameter."
),
)
frequency_context_size: Optional[int] = Field(
None,
description="Token window for frequency_penalty (mlx_lm default: 20).",
)
logit_bias: Optional[dict[int, float]] = Field(
None, description="Additive logit bias keyed by token id."
Expand All @@ -332,6 +372,11 @@ def shared_generation_kwargs(self) -> dict[str, Any]:
"top_k",
"min_p",
"repetition_penalty",
"repetition_context_size",
"presence_penalty",
"presence_context_size",
"frequency_penalty",
"frequency_context_size",
"logit_bias",
)

Expand Down
12 changes: 12 additions & 0 deletions mlx_vlm/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_responses_endpoint_forwards_new_sampling_args(client):
"top_k": 40,
"min_p": 0.08,
"repetition_penalty": 1.15,
"repetition_context_size": 4096,
"logit_bias": {"12": -1.5},
"enable_thinking": False,
"thinking_budget": 24,
Expand All @@ -82,6 +83,7 @@ def test_responses_endpoint_forwards_new_sampling_args(client):
assert mock_generate.call_args.kwargs["top_k"] == 40
assert mock_generate.call_args.kwargs["min_p"] == 0.08
assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
assert mock_generate.call_args.kwargs["repetition_context_size"] == 4096
assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
assert mock_generate.call_args.kwargs["enable_thinking"] is False
assert mock_generate.call_args.kwargs["thinking_budget"] == 24
Expand Down Expand Up @@ -118,6 +120,11 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
"top_k": 40,
"min_p": 0.08,
"repetition_penalty": 1.15,
"repetition_context_size": 4096,
"presence_penalty": 0.3,
"presence_context_size": 256,
"frequency_penalty": 0.05,
"frequency_context_size": 512,
"logit_bias": {"12": -1.5},
"resize_shape": [512],
},
Expand All @@ -128,5 +135,10 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
assert mock_generate.call_args.kwargs["top_k"] == 40
assert mock_generate.call_args.kwargs["min_p"] == 0.08
assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
assert mock_generate.call_args.kwargs["repetition_context_size"] == 4096
assert mock_generate.call_args.kwargs["presence_penalty"] == 0.3
assert mock_generate.call_args.kwargs["presence_context_size"] == 256
assert mock_generate.call_args.kwargs["frequency_penalty"] == 0.05
assert mock_generate.call_args.kwargs["frequency_context_size"] == 512
assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)