diff --git a/mlx_vlm/server.py b/mlx_vlm/server.py
index 04c011f6a..1a97da9f8 100644
--- a/mlx_vlm/server.py
+++ b/mlx_vlm/server.py
@@ -1,4 +1,5 @@
 import argparse
+import asyncio
 import gc
 import json
 import os
@@ -12,7 +13,7 @@
 
 import mlx.core as mx
 import uvicorn
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from huggingface_hub import scan_cache_dir
@@ -43,12 +44,49 @@
 
 DEFAULT_SERVER_HOST = "0.0.0.0"
 DEFAULT_SERVER_PORT = 8080
+DEFAULT_REQUEST_TIMEOUT = 300
+
+
+def get_request_timeout() -> int:
+    """Request timeout in seconds. Must be > 0."""
+    try:
+        value = int(os.environ.get("REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT))
+    except ValueError:
+        raise ValueError(
+            f"REQUEST_TIMEOUT must be a valid integer, got: "
+            f"{os.environ.get('REQUEST_TIMEOUT')!r}"
+        )
+    if value <= 0:
+        raise ValueError(f"REQUEST_TIMEOUT must be > 0, got {value}")
+    return value
 
 
 def get_prefill_step_size():
     return int(os.environ.get("PREFILL_STEP_SIZE", DEFAULT_PREFILL_STEP_SIZE))
 
 
+def get_max_context_tokens() -> int:
+    """Maximum prompt tokens before rejecting a request. 0 means no limit."""
+    value = int(os.environ.get("MAX_CONTEXT_TOKENS", 0))
+    if value < 0:
+        raise ValueError(f"MAX_CONTEXT_TOKENS must be >= 0, got {value}")
+    return value
+
+
+def check_context_length(prompt: str, processor, max_context: int) -> None:
+    """Raise HTTP 400 if the tokenized prompt exceeds *max_context* tokens."""
+    if max_context <= 0:
+        return
+    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
+    token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
+    if token_count > max_context:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Prompt length ({token_count} tokens) exceeds maximum context "
+            f"window ({max_context} tokens). Reduce your prompt or increase --max-context-tokens.",
+        )
+
+
 def get_quantized_kv_bits(model: str):
     kv_bits = float(os.environ.get("KV_BITS", 0))
     if kv_bits == 0:
@@ -382,6 +420,10 @@ class OpenAIRequest(GenerationParams, TemplateParams):
     stream: bool = Field(
         False, description="Whether to stream the response chunk by chunk."
     )
+    response_format: Optional[dict] = Field(
+        None,
+        description='Output format: {"type": "text"} or {"type": "json_object"}.',
+    )
 
     def generation_kwargs(self) -> dict[str, Any]:
         kwargs = self.dump_kwargs("max_output_tokens")
@@ -602,6 +644,10 @@ class UsageStats(OpenAIUsage):
 
 class ChatRequest(GenerationRequest):
     messages: List[ChatMessage]
+    response_format: Optional[dict] = Field(
+        None,
+        description='Output format: {"type": "text"} or {"type": "json_object"}.',
+    )
 
 
 class ChatChoice(BaseModel):
@@ -630,6 +676,29 @@ class ChatStreamChunk(BaseModel):
     usage: Optional[UsageStats]
 
 
+def resolve_response_format(messages, response_format):
+    """Inject a JSON instruction if the json_object format is requested.
+
+    Returns a new list — the original messages list is not mutated.
+    """
+    if not response_format:
+        return messages
+    fmt_type = response_format.get("type", "text")
+    if fmt_type not in ("text", "json_object"):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported response_format type: '{fmt_type}'. "
+            "Supported types are 'text' and 'json_object'.",
+        )
+    if fmt_type == "json_object":
+        json_instruction = (
+            "You must respond with valid JSON only. "
+            "Do not include any text outside the JSON object."
+        )
+        return [{"role": "system", "content": json_instruction}] + messages
+    return messages
+
+
 def build_generation_kwargs(
     request: Any,
     template_kwargs: dict[str, Any],
@@ -837,6 +906,10 @@ def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, mode
         print("no input")
         raise HTTPException(status_code=400, detail="Missing input.")
 
+    chat_messages = resolve_response_format(
+        chat_messages, openai_request.response_format
+    )
+
     template_kwargs = openai_request.template_kwargs()
     formatted_prompt = apply_chat_template(
         processor,
@@ -845,6 +918,7 @@
         num_images=len(images),
         **template_kwargs,
     )
+    check_context_length(formatted_prompt, processor, get_max_context_tokens())
 
     generation_kwargs = build_generation_kwargs(openai_request, template_kwargs)
     generated_at = datetime.now().timestamp()
@@ -982,15 +1056,35 @@ async def stream_generator():
     else:
         # Non-streaming response
         try:
-            # Use generate from generate.py
-            result = generate(
-                model=model,
-                processor=processor,
-                prompt=formatted_prompt,
-                image=images,
-                verbose=False,  # stats are passed in the response
-                **generation_kwargs,
-            )
+            # Use generate from generate.py, with request timeout.
+            # NOTE: wait_for cancels the future but cannot interrupt the
+            # sync generate() running in the thread pool. The thread will
+            # run to completion; only the await is aborted.
+            timeout = get_request_timeout()
+            loop = asyncio.get_running_loop()
+            try:
+                result = await asyncio.wait_for(
+                    loop.run_in_executor(
+                        None,
+                        lambda: generate(
+                            model=model,
+                            processor=processor,
+                            prompt=formatted_prompt,
+                            image=images,
+                            verbose=False,
+                            **generation_kwargs,
+                        ),
+                    ),
+                    timeout=timeout,
+                )
+            except asyncio.TimeoutError:
+                print(f"[cancellation] /responses generation timed out after {timeout}s.")
+                mx.clear_cache()
+                gc.collect()
+                raise HTTPException(
+                    status_code=504,
+                    detail=f"Generation timed out after {timeout} seconds.",
+                )
             # Clean up resources
             mx.clear_cache()
             gc.collect()
@@ -1051,7 +1145,7 @@ async def stream_generator():
     "/chat/completions", response_model=None
 )  # Response model handled dynamically based on stream flag
 @app.post("/v1/chat/completions", response_model=None, include_in_schema=False)
-async def chat_completions_endpoint(request: ChatRequest):
+async def chat_completions_endpoint(request: ChatRequest, raw_request: Request):
     """
     Generate text based on a prompt and optional images.
     Prompt must be a list of chat messages, including system, user, and assistant messages.
@@ -1103,6 +1197,9 @@ async def chat_completions_endpoint(request: ChatRequest):
     tool_parser_type = _infer_tool_parser(tokenizer.chat_template)
     if tool_parser_type is not None:
         tool_module = load_tool_module(tool_parser_type)
+    processed_messages = resolve_response_format(
+        processed_messages, request.response_format
+    )
     template_kwargs = request.template_kwargs()
     formatted_prompt = apply_chat_template(
         processor,
@@ -1113,6 +1210,7 @@
         tools=tools,
         **template_kwargs,
     )
+    check_context_length(formatted_prompt, processor, get_max_context_tokens())
 
     generation_kwargs = build_generation_kwargs(request, template_kwargs)
     if request.stream:
@@ -1134,6 +1232,13 @@ async def stream_generator():
                 output_text = ""
                 request_id = f"chatcmpl-{uuid.uuid4()}"
                 for chunk in token_iterator:
+                    # Check if client disconnected
+                    if await raw_request.is_disconnected():
+                        print("[cancellation] Client disconnected during /chat/completions streaming, aborting generation.")
+                        if token_iterator is not None:
+                            token_iterator.close()
+                        return
+
                     if chunk is None or not hasattr(chunk, "text"):
                         print("Warning: Received unexpected chunk format:", chunk)
                         continue
@@ -1199,6 +1304,12 @@ async def stream_generator():
 
                 yield "data: [DONE]\n\n"
 
+            except asyncio.CancelledError:
+                print("[cancellation] /chat/completions stream cancelled (client disconnect).")
+                if token_iterator is not None:
+                    token_iterator.close()
+                raise
+
             except Exception as e:
                 print(f"Error during stream generation: {e}")
                 traceback.print_exc()
@@ -1223,17 +1334,37 @@ async def stream_generator():
     else:
         # Non-streaming response
         try:
-            # Use generate from generate.py
-            gen_result = generate(
-                model=model,
-                processor=processor,
-                prompt=formatted_prompt,
-                image=images,
-                audio=audio,
-                verbose=False,  # Keep API output clean
-                vision_cache=model_cache.get("vision_cache"),
-                **generation_kwargs,
-            )
+            # Use generate from generate.py, with request timeout.
+            # NOTE: wait_for cancels the future but cannot interrupt the
+            # sync generate() running in the thread pool. The thread will
+            # run to completion; only the await is aborted.
+            timeout = get_request_timeout()
+            loop = asyncio.get_running_loop()
+            try:
+                gen_result = await asyncio.wait_for(
+                    loop.run_in_executor(
+                        None,
+                        lambda: generate(
+                            model=model,
+                            processor=processor,
+                            prompt=formatted_prompt,
+                            image=images,
+                            audio=audio,
+                            verbose=False,
+                            vision_cache=model_cache.get("vision_cache"),
+                            **generation_kwargs,
+                        ),
+                    ),
+                    timeout=timeout,
+                )
+            except asyncio.TimeoutError:
+                print(f"[cancellation] /chat/completions generation timed out after {timeout}s.")
+                mx.clear_cache()
+                gc.collect()
+                raise HTTPException(
+                    status_code=504,
+                    detail=f"Generation timed out after {timeout} seconds.",
+                )
             # Clean up resources
             mx.clear_cache()
             gc.collect()
@@ -1433,6 +1564,13 @@ def main():
         default=DEFAULT_QUANTIZED_KV_START,
         help="Start index (of token) for the quantized KV cache.",
     )
+    parser.add_argument(
+        "--max-context-tokens",
+        type=int,
+        default=0,
+        help="Maximum context window in tokens. Requests exceeding this are rejected. "
+        "0 means no limit. (default: %(default)s)",
+    )
     parser.add_argument(
         "--reload",
         action="store_true",
@@ -1454,6 +1592,7 @@ def main():
     os.environ["KV_QUANT_SCHEME"] = args.kv_quant_scheme
     os.environ["MAX_KV_SIZE"] = str(args.max_kv_size)
     os.environ["QUANTIZED_KV_START"] = str(args.quantized_kv_start)
+    os.environ["MAX_CONTEXT_TOKENS"] = str(args.max_context_tokens)
 
     uvicorn.run(
         "mlx_vlm.server:app",
diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py
index 270a82d77..0ddbd3980 100644
--- a/mlx_vlm/tests/test_server.py
+++ b/mlx_vlm/tests/test_server.py
@@ -130,3 +130,164 @@ def test_chat_completions_endpoint_forwards_explicit_sampling_args(client):
     assert mock_generate.call_args.kwargs["repetition_penalty"] == 1.15
     assert mock_generate.call_args.kwargs["logit_bias"] == {12: -1.5}
     assert mock_generate.call_args.kwargs["resize_shape"] == (512, 512)
+
+
+# ---------------------------------------------------------------------------
+# Context tracking tests
+# ---------------------------------------------------------------------------
+
+
+def test_check_context_length_within_limit():
+    """Should not raise when within limit."""
+    fake_proc = SimpleNamespace(
+        tokenizer=SimpleNamespace(encode=lambda s, add_special_tokens=False: list(range(10))),
+    )
+    server.check_context_length("short", fake_proc, 100)  # No exception
+
+
+def test_check_context_length_exceeds_limit():
+    """Should raise HTTPException when exceeding limit."""
+    from fastapi import HTTPException as _HTTPException
+
+    fake_proc = SimpleNamespace(
+        tokenizer=SimpleNamespace(encode=lambda s, add_special_tokens=False: list(range(200))),
+    )
+    with pytest.raises(_HTTPException) as exc_info:
+        server.check_context_length("long", fake_proc, 100)
+    assert exc_info.value.status_code == 400
+    assert "200 tokens" in exc_info.value.detail
+
+
+def test_check_context_length_zero_unlimited():
+    """max_context=0 should skip check entirely."""
+    server.check_context_length("anything", None, 0)  # No exception
+
+
+def test_get_max_context_tokens_default(monkeypatch):
+    """Default should be 0 (unlimited)."""
+    monkeypatch.delenv("MAX_CONTEXT_TOKENS", raising=False)
+    assert server.get_max_context_tokens() == 0
+
+
+def test_get_max_context_tokens_from_env(monkeypatch):
+    """Should read from MAX_CONTEXT_TOKENS env var."""
+    monkeypatch.setenv("MAX_CONTEXT_TOKENS", "16384")
+    assert server.get_max_context_tokens() == 16384
+
+
+def test_get_max_context_tokens_rejects_negative(monkeypatch):
+    """Negative values should be rejected."""
+    monkeypatch.setenv("MAX_CONTEXT_TOKENS", "-1")
+    with pytest.raises(ValueError, match="must be >= 0"):
+        server.get_max_context_tokens()
+
+
+# ---------------------------------------------------------------------------
+# JSON mode / response_format tests
+# ---------------------------------------------------------------------------
+
+
+def test_resolve_response_format_json_adds_instruction():
+    msgs = [{"role": "user", "content": "hi"}]
+    result = server.resolve_response_format(msgs, {"type": "json_object"})
+    assert result[0]["role"] == "system"
+    assert "json" in result[0]["content"].lower()
+    assert len(result) == 2
+    # Original list should not be mutated
+    assert len(msgs) == 1
+
+
+def test_resolve_response_format_text_no_change():
+    msgs = [{"role": "user", "content": "hi"}]
+    result = server.resolve_response_format(msgs, {"type": "text"})
+    assert len(result) == 1
+
+
+def test_resolve_response_format_none_no_change():
+    msgs = [{"role": "user", "content": "hi"}]
+    result = server.resolve_response_format(msgs, None)
+    assert len(result) == 1
+
+
+def test_resolve_response_format_unsupported_type():
+    """Unsupported response_format type should raise 400."""
+    from fastapi import HTTPException as _Exc
+    msgs = [{"role": "user", "content": "hi"}]
+    with pytest.raises(_Exc) as exc_info:
+        server.resolve_response_format(msgs, {"type": "xml"})
+    assert exc_info.value.status_code == 400
+
+
+def test_chat_completions_json_mode_accepted(client):
+    model = SimpleNamespace()
+    processor = SimpleNamespace()
+    config = SimpleNamespace(model_type="qwen2_vl")
+    result = SimpleNamespace(
+        text='{"answer": 42}',
+        prompt_tokens=8,
+        generation_tokens=4,
+        total_tokens=12,
+        prompt_tps=10.0,
+        generation_tps=5.0,
+        peak_memory=0.1,
+    )
+
+    with (
+        patch.object(
+            server, "get_cached_model", return_value=(model, processor, config)
+        ),
+        patch.object(
+            server, "apply_chat_template", return_value="prompt"
+        ) as mock_template,
+        patch.object(server, "generate", return_value=result),
+    ):
+        response = client.post(
+            "/chat/completions",
+            json={
+                "model": "demo",
+                "messages": [{"role": "user", "content": "Give me JSON"}],
+                "response_format": {"type": "json_object"},
+            },
+        )
+
+    assert response.status_code == 200
+    # The first message passed to apply_chat_template should be the injected system msg
+    chat_messages = mock_template.call_args.args[2]
+    assert chat_messages[0]["role"] == "system"
+    assert "json" in chat_messages[0]["content"].lower()
+
+
+def test_responses_json_mode_accepted(client):
+    model = SimpleNamespace()
+    processor = SimpleNamespace()
+    config = SimpleNamespace(model_type="qwen2_vl")
+    result = SimpleNamespace(
+        text='{"answer": 42}',
+        prompt_tokens=8,
+        generation_tokens=4,
+        total_tokens=12,
+    )
+
+    with (
+        patch.object(
+            server, "get_cached_model", return_value=(model, processor, config)
+        ),
+        patch.object(
+            server, "apply_chat_template", return_value="prompt"
+        ) as mock_template,
+        patch.object(server, "generate", return_value=result),
+    ):
+        response = client.post(
+            "/responses",
+            json={
+                "model": "demo",
+                "input": "Give me JSON",
+                "response_format": {"type": "json_object"},
+            },
+        )
+
+    assert response.status_code == 200
+    # The first message passed to apply_chat_template should be the injected system msg
+    chat_messages = mock_template.call_args.args[2]
+    assert chat_messages[0]["role"] == "system"
+    assert "json" in chat_messages[0]["content"].lower()
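
Example client sketch (illustrative only, not part of the patch). It exercises the new response_format option and the two failure modes introduced above: HTTP 400 from check_context_length() when the prompt exceeds --max-context-tokens, and HTTP 504 when generation exceeds REQUEST_TIMEOUT. The base URL assumes the default port (DEFAULT_SERVER_PORT = 8080); the model id and the exact shape of the success payload are placeholder assumptions, not taken from this diff.

# Hypothetical usage sketch; model id, URL, and success-payload shape are assumptions.
import requests

BASE_URL = "http://localhost:8080"  # DEFAULT_SERVER_PORT

resp = requests.post(
    f"{BASE_URL}/chat/completions",
    json={
        "model": "mlx-community/Qwen2-VL-2B-Instruct-4bit",  # placeholder model id
        "messages": [{"role": "user", "content": "Describe the weather as JSON."}],
        "response_format": {"type": "json_object"},
    },
    timeout=330,  # client-side timeout a bit above the server default REQUEST_TIMEOUT=300
)

if resp.status_code == 400:
    # Rejected by check_context_length() or by an unsupported response_format type.
    print("Bad request:", resp.json()["detail"])
elif resp.status_code == 504:
    # Generation exceeded REQUEST_TIMEOUT seconds on the server.
    print("Timed out:", resp.json()["detail"])
else:
    # Assuming an OpenAI-style chat-completions payload.
    print(resp.json()["choices"][0]["message"]["content"])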