diff --git a/mlx_vlm/server.py b/mlx_vlm/server.py index 04c011f6a..d7c66ab09 100644 --- a/mlx_vlm/server.py +++ b/mlx_vlm/server.py @@ -602,6 +602,13 @@ class UsageStats(OpenAIUsage): class ChatRequest(GenerationRequest): messages: List[ChatMessage] + tools: Optional[List[dict]] = Field( + None, description="Tool definitions the model may call.", + ) + tool_choice: Optional[Any] = Field( + None, + description='Tool choice: "none", "auto", "required", or specific tool.', + ) class ChatChoice(BaseModel): @@ -630,6 +637,68 @@ class ChatStreamChunk(BaseModel): usage: Optional[UsageStats] +class InvalidToolChoiceError(ValueError): + """Raised when tool_choice is invalid.""" + + +def resolve_tool_choice( + tools: Optional[list], + tool_choice: Optional[Any], +) -> tuple[Optional[list], Optional[str]]: + """Apply tool_choice policy to the tools list. + + Args: + tools: The original tools list from the request. + tool_choice: The tool_choice value (``"none"``, ``"auto"``, + ``"required"``, or a dict specifying a particular tool). + + Returns: + Tuple of (filtered_tools, system_instruction). + ``filtered_tools`` may be ``None`` (when choice is ``"none"``). + ``system_instruction`` is an optional string to prepend to the + prompt to steer the model's tool-use behavior. + + Raises: + InvalidToolChoiceError: If tool_choice is not a recognized value + or references an unknown tool name. + """ + if not tools or tool_choice is None or tool_choice == "auto": + return tools, None + + if tool_choice == "none": + return None, None + + if tool_choice == "required": + return tools, "You must call one of the available tools to answer this request." + + # Specific tool: {"type": "function", "function": {"name": "..."}} + if isinstance(tool_choice, dict): + func = tool_choice.get("function", {}) + name = func.get("name") if isinstance(func, dict) else None + if name: + filtered = [ + t for t in tools + if (t.get("function", {}) or {}).get("name") == name + or t.get("name") == name + ] + if not filtered: + raise InvalidToolChoiceError( + f"Tool '{name}' not found in the provided tools list" + ) + return ( + filtered, + f'You must call the "{name}" tool to answer this request.', + ) + + if isinstance(tool_choice, str): + raise InvalidToolChoiceError( + f"Invalid tool_choice value: '{tool_choice}'. " + "Must be 'none', 'auto', 'required', or a tool specification dict." + ) + + return tools, None + + def build_generation_kwargs( request: Any, template_kwargs: dict[str, Any], @@ -1095,11 +1164,20 @@ async def chat_completions_endpoint(request: ChatRequest): if hasattr(request, "tools"): tools = request.tools + # Apply tool_choice policy + tool_choice = getattr(request, "tool_choice", None) + try: + tools, tool_instruction = resolve_tool_choice(tools, tool_choice) + except InvalidToolChoiceError as e: + raise HTTPException(status_code=400, detail=str(e)) + if tool_instruction: + processed_messages.insert(0, {"role": "system", "content": tool_instruction}) + tool_parser_type = None tokenizer = ( processor.tokenizer if hasattr(processor, "tokenizer") else processor ) - if hasattr(tokenizer, "chat_template"): + if tools and hasattr(tokenizer, "chat_template"): tool_parser_type = _infer_tool_parser(tokenizer.chat_template) if tool_parser_type is not None: tool_module = load_tool_module(tool_parser_type) diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 270a82d77..b0deab74f 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -130,3 +130,147 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client): assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15 assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5} assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512) + + +# --------------------------------------------------------------------------- +# tool_choice tests +# --------------------------------------------------------------------------- + + +def test_resolve_tool_choice_auto_passthrough(): + """tool_choice='auto' should return tools unchanged.""" + tools = [{"function": {"name": "search"}}] + result_tools, instruction = server.resolve_tool_choice(tools, "auto") + assert result_tools == tools + assert instruction is None + + +def test_resolve_tool_choice_none_strips_tools(): + """tool_choice='none' should return None for tools.""" + tools = [{"function": {"name": "search"}}] + result_tools, instruction = server.resolve_tool_choice(tools, "none") + assert result_tools is None + assert instruction is None + + +def test_resolve_tool_choice_required_adds_instruction(): + """tool_choice='required' should keep tools and add instruction.""" + tools = [{"function": {"name": "search"}}] + result_tools, instruction = server.resolve_tool_choice(tools, "required") + assert result_tools == tools + assert instruction is not None + assert "must call" in instruction.lower() + + +def test_resolve_tool_choice_specific_function(): + """Specific function tool_choice should filter to that tool.""" + tools = [ + {"function": {"name": "search"}}, + {"function": {"name": "fetch"}}, + {"function": {"name": "read"}}, + ] + choice = {"type": "function", "function": {"name": "fetch"}} + result_tools, instruction = server.resolve_tool_choice(tools, choice) + assert len(result_tools) == 1 + assert result_tools[0]["function"]["name"] == "fetch" + assert "fetch" in instruction + + +def test_resolve_tool_choice_unknown_tool_returns_400(): + """Unknown function name should raise InvalidToolChoiceError.""" + tools = [{"function": {"name": "search"}}] + choice = {"type": "function", "function": {"name": "nonexistent"}} + with pytest.raises(server.InvalidToolChoiceError, match="not found"): + server.resolve_tool_choice(tools, choice) + + +def test_resolve_tool_choice_invalid_string_returns_error(): + """Invalid string tool_choice should raise InvalidToolChoiceError.""" + tools = [{"function": {"name": "search"}}] + with pytest.raises(server.InvalidToolChoiceError, match="Invalid tool_choice"): + server.resolve_tool_choice(tools, "bogus") + + +def test_resolve_tool_choice_none_value_passthrough(): + """tool_choice=None should return tools unchanged.""" + tools = [{"function": {"name": "search"}}] + result_tools, instruction = server.resolve_tool_choice(tools, None) + assert result_tools == tools + assert instruction is None + + +def test_resolve_tool_choice_no_tools(): + """No tools should return None regardless of tool_choice.""" + result_tools, instruction = server.resolve_tool_choice(None, "required") + assert result_tools is None + assert instruction is None + + +def test_chat_completions_tool_choice_none_strips_tools(client): + """tool_choice='none' should not pass tools to apply_chat_template.""" + model = SimpleNamespace() + processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template="")) + config = SimpleNamespace(model_type="test") + result = SimpleNamespace( + text="Hi", + prompt_tokens=5, + generation_tokens=1, + total_tokens=6, + prompt_tps=100.0, + generation_tps=50.0, + peak_memory=1.0, + ) + + with ( + patch.object(server, "get_cached_model", return_value=(model, processor, config)), + patch.object(server, "apply_chat_template", return_value="prompt") as mock_tmpl, + patch.object(server, "generate", return_value=result), + ): + resp = client.post( + "/chat/completions", + json={ + "model": "demo", + "messages": [{"role": "user", "content": "hello"}], + "tools": [{"type": "function", "function": {"name": "search", "parameters": {}}}], + "tool_choice": "none", + }, + ) + assert resp.status_code == 200 + # tools should be None in the template call + assert mock_tmpl.call_args.kwargs.get("tools") is None + + +def test_chat_completions_tool_choice_required_adds_system_msg(client): + """tool_choice='required' should inject a system message.""" + model = SimpleNamespace() + processor = SimpleNamespace(tokenizer=SimpleNamespace(chat_template="")) + config = SimpleNamespace(model_type="test") + result = SimpleNamespace( + text="Hi", + prompt_tokens=5, + generation_tokens=1, + total_tokens=6, + prompt_tps=100.0, + generation_tps=50.0, + peak_memory=1.0, + ) + + with ( + patch.object(server, "get_cached_model", return_value=(model, processor, config)), + patch.object(server, "apply_chat_template", return_value="prompt") as mock_tmpl, + patch.object(server, "generate", return_value=result), + ): + resp = client.post( + "/chat/completions", + json={ + "model": "demo", + "messages": [{"role": "user", "content": "search for test"}], + "tools": [{"type": "function", "function": {"name": "search", "parameters": {}}}], + "tool_choice": "required", + }, + ) + assert resp.status_code == 200 + # Check that messages passed to template include the system instruction + messages_arg = mock_tmpl.call_args[0][2] # 3rd positional arg + system_msgs = [m for m in messages_arg if m.get("role") == "system"] + assert any("must call" in m.get("content", "").lower() for m in system_msgs)