Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion mlx_vlm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,13 @@ class UsageStats(OpenAIUsage):

class ChatRequest(GenerationRequest):
    # Conversation history in OpenAI chat-completions format.
    messages: List[ChatMessage]
    # OpenAI-compatible tool-calling fields; both default to None so plain
    # chat requests are unaffected.
    tools: Optional[List[dict]] = Field(
        None, description="Tool definitions the model may call.",
    )
    # NOTE(review): typed Any because it may be a policy string or a
    # tool-spec dict; validated later by resolve_tool_choice.
    tool_choice: Optional[Any] = Field(
        None,
        description='Tool choice: "none", "auto", "required", or specific tool.',
    )


class ChatChoice(BaseModel):
Expand Down Expand Up @@ -630,6 +637,68 @@ class ChatStreamChunk(BaseModel):
usage: Optional[UsageStats]


class InvalidToolChoiceError(ValueError):
    """Signals that a request supplied an unusable ``tool_choice`` value."""


def resolve_tool_choice(
    tools: Optional[list],
    tool_choice: Optional[Any],
) -> tuple[Optional[list], Optional[str]]:
    """Apply the OpenAI-style ``tool_choice`` policy to the tools list.

    Args:
        tools: The original tools list from the request.
        tool_choice: The tool_choice value (``"none"``, ``"auto"``,
            ``"required"``, or a dict specifying a particular tool).

    Returns:
        Tuple of ``(filtered_tools, system_instruction)``.
        ``filtered_tools`` may be ``None`` (when choice is ``"none"``).
        ``system_instruction`` is an optional string to prepend to the
        prompt to steer the model's tool-use behavior.

    Raises:
        InvalidToolChoiceError: If tool_choice is not a recognized value,
            is a malformed tool-specification dict, or references an
            unknown tool name.
    """
    # No tools at all, or the default/auto policy: nothing to resolve.
    if not tools or tool_choice is None or tool_choice == "auto":
        return tools, None

    if tool_choice == "none":
        # Caller explicitly disabled tool use: drop the tools entirely.
        return None, None

    if tool_choice == "required":
        return tools, "You must call one of the available tools to answer this request."

    # Specific tool: {"type": "function", "function": {"name": "..."}}
    if isinstance(tool_choice, dict):
        func = tool_choice.get("function", {})
        name = func.get("name") if isinstance(func, dict) else None
        if not name:
            # Fix: a malformed spec previously fell through and silently
            # behaved like "auto"; reject it so the endpoint returns 400.
            raise InvalidToolChoiceError(
                "Invalid tool_choice dict: expected "
                '{"type": "function", "function": {"name": ...}}'
            )
        # Match both nested {"function": {"name": ...}} and flat
        # {"name": ...} tool definitions; `or {}` guards function: None.
        filtered = [
            t for t in tools
            if (t.get("function", {}) or {}).get("name") == name
            or t.get("name") == name
        ]
        if not filtered:
            raise InvalidToolChoiceError(
                f"Tool '{name}' not found in the provided tools list"
            )
        return (
            filtered,
            f'You must call the "{name}" tool to answer this request.',
        )

    # Fix: any other type (string or otherwise) is not a recognized policy;
    # previously non-str, non-dict values were silently treated as "auto".
    raise InvalidToolChoiceError(
        f"Invalid tool_choice value: '{tool_choice}'. "
        "Must be 'none', 'auto', 'required', or a tool specification dict."
    )


def build_generation_kwargs(
request: Any,
template_kwargs: dict[str, Any],
Expand Down Expand Up @@ -1095,11 +1164,20 @@ async def chat_completions_endpoint(request: ChatRequest):
if hasattr(request, "tools"):
tools = request.tools

# Apply tool_choice policy
tool_choice = getattr(request, "tool_choice", None)
try:
tools, tool_instruction = resolve_tool_choice(tools, tool_choice)
except InvalidToolChoiceError as e:
raise HTTPException(status_code=400, detail=str(e))
if tool_instruction:
processed_messages.insert(0, {"role": "system", "content": tool_instruction})

tool_parser_type = None
tokenizer = (
processor.tokenizer if hasattr(processor, "tokenizer") else processor
)
if hasattr(tokenizer, "chat_template"):
if tools and hasattr(tokenizer, "chat_template"):
tool_parser_type = _infer_tool_parser(tokenizer.chat_template)
if tool_parser_type is not None:
tool_module = load_tool_module(tool_parser_type)
Expand Down
144 changes: 144 additions & 0 deletions mlx_vlm/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,147 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)


# ---------------------------------------------------------------------------
# tool_choice tests
# ---------------------------------------------------------------------------


def test_resolve_tool_choice_auto_passthrough():
    """With tool_choice='auto', the tools list must pass through untouched."""
    available = [{"function": {"name": "search"}}]
    kept, note = server.resolve_tool_choice(available, "auto")
    assert kept == available
    assert note is None


def test_resolve_tool_choice_none_strips_tools():
    """With tool_choice='none', no tools should survive resolution."""
    available = [{"function": {"name": "search"}}]
    kept, note = server.resolve_tool_choice(available, "none")
    assert kept is None
    assert note is None


def test_resolve_tool_choice_required_adds_instruction():
    """With tool_choice='required', tools are kept and an instruction is emitted."""
    available = [{"function": {"name": "search"}}]
    kept, note = server.resolve_tool_choice(available, "required")
    assert kept == available
    assert note is not None
    assert "must call" in note.lower()


def test_resolve_tool_choice_specific_function():
    """A specific function choice narrows the tools list to that one entry."""
    available = [{"function": {"name": n}} for n in ("search", "fetch", "read")]
    choice = {"type": "function", "function": {"name": "fetch"}}
    kept, note = server.resolve_tool_choice(available, choice)
    assert [t["function"]["name"] for t in kept] == ["fetch"]
    assert "fetch" in note


def test_resolve_tool_choice_unknown_tool_returns_400():
    """Choosing a function that is not offered must raise InvalidToolChoiceError."""
    available = [{"function": {"name": "search"}}]
    bad_choice = {"type": "function", "function": {"name": "nonexistent"}}
    with pytest.raises(server.InvalidToolChoiceError, match="not found"):
        server.resolve_tool_choice(available, bad_choice)


def test_resolve_tool_choice_invalid_string_returns_error():
    """An unrecognized string policy must raise InvalidToolChoiceError."""
    available = [{"function": {"name": "search"}}]
    with pytest.raises(server.InvalidToolChoiceError, match="Invalid tool_choice"):
        server.resolve_tool_choice(available, "bogus")


def test_resolve_tool_choice_none_value_passthrough():
    """Omitting tool_choice (None) leaves the tools untouched."""
    available = [{"function": {"name": "search"}}]
    assert server.resolve_tool_choice(available, None) == (available, None)


def test_resolve_tool_choice_no_tools():
    """Without any tools, every policy resolves to (None, None)."""
    kept, note = server.resolve_tool_choice(None, "required")
    assert (kept, note) == (None, None)


def test_chat_completions_tool_choice_none_strips_tools(client):
    """tool_choice='none' should not pass tools to apply_chat_template."""
    fake_model = SimpleNamespace()
    fake_processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=""))
    fake_config = SimpleNamespace(model_type="test")
    fake_result = SimpleNamespace(
        text="Hi",
        prompt_tokens=5,
        generation_tokens=1,
        total_tokens=6,
        prompt_tps=100.0,
        generation_tps=50.0,
        peak_memory=1.0,
    )

    payload = {
        "model": "demo",
        "messages": [{"role": "user", "content": "hello"}],
        "tools": [
            {"type": "function", "function": {"name": "search", "parameters": {}}}
        ],
        "tool_choice": "none",
    }

    with (
        patch.object(
            server,
            "get_cached_model",
            return_value=(fake_model, fake_processor, fake_config),
        ),
        patch.object(
            server, "apply_chat_template", return_value="prompt"
        ) as template_mock,
        patch.object(server, "generate", return_value=fake_result),
    ):
        response = client.post("/chat/completions", json=payload)

    assert response.status_code == 200
    # With tool_choice='none', the template call must receive no tools at all.
    assert template_mock.call_args.kwargs.get("tools") is None


def test_chat_completions_tool_choice_required_adds_system_msg(client):
    """tool_choice='required' should inject a system message."""
    fake_model = SimpleNamespace()
    fake_processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=""))
    fake_config = SimpleNamespace(model_type="test")
    fake_result = SimpleNamespace(
        text="Hi",
        prompt_tokens=5,
        generation_tokens=1,
        total_tokens=6,
        prompt_tps=100.0,
        generation_tps=50.0,
        peak_memory=1.0,
    )

    payload = {
        "model": "demo",
        "messages": [{"role": "user", "content": "search for test"}],
        "tools": [
            {"type": "function", "function": {"name": "search", "parameters": {}}}
        ],
        "tool_choice": "required",
    }

    with (
        patch.object(
            server,
            "get_cached_model",
            return_value=(fake_model, fake_processor, fake_config),
        ),
        patch.object(
            server, "apply_chat_template", return_value="prompt"
        ) as template_mock,
        patch.object(server, "generate", return_value=fake_result),
    ):
        response = client.post("/chat/completions", json=payload)

    assert response.status_code == 200
    # The third positional argument to apply_chat_template is the message
    # list; a system instruction about mandatory tool use must be present.
    rendered_messages = template_mock.call_args[0][2]
    assert any(
        m.get("role") == "system" and "must call" in m.get("content", "").lower()
        for m in rendered_messages
    )