Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 161 additions & 22 deletions mlx_vlm/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import gc
import json
import os
Expand All @@ -12,7 +13,7 @@

import mlx.core as mx
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import scan_cache_dir
Expand Down Expand Up @@ -43,12 +44,49 @@

DEFAULT_SERVER_HOST = "0.0.0.0"
DEFAULT_SERVER_PORT = 8080
DEFAULT_REQUEST_TIMEOUT = 300


def get_request_timeout() -> int:
    """Per-request generation timeout in seconds.

    Reads the REQUEST_TIMEOUT environment variable, falling back to
    ``DEFAULT_REQUEST_TIMEOUT`` when unset.

    Returns:
        The configured timeout; always strictly positive.

    Raises:
        ValueError: if REQUEST_TIMEOUT is not a valid integer or is <= 0.
    """
    raw = os.environ.get("REQUEST_TIMEOUT")
    if raw is None:
        return DEFAULT_REQUEST_TIMEOUT
    try:
        value = int(raw)
    except ValueError as exc:
        # Chain the original parse error (instead of swallowing it) and reuse
        # the value we already read rather than hitting os.environ twice.
        raise ValueError(
            f"REQUEST_TIMEOUT must be a valid integer, got: {raw!r}"
        ) from exc
    if value <= 0:
        raise ValueError(f"REQUEST_TIMEOUT must be > 0, got {value}")
    return value


def get_prefill_step_size():
    """Prefill step size, overridable via the PREFILL_STEP_SIZE env var."""
    configured = os.environ.get("PREFILL_STEP_SIZE", DEFAULT_PREFILL_STEP_SIZE)
    return int(configured)


def get_max_context_tokens() -> int:
    """Maximum prompt tokens before rejecting a request. 0 means no limit.

    Reads the MAX_CONTEXT_TOKENS environment variable (default 0).

    Raises:
        ValueError: if MAX_CONTEXT_TOKENS is not a valid integer or is < 0.
    """
    raw = os.environ.get("MAX_CONTEXT_TOKENS", 0)
    try:
        value = int(raw)
    except ValueError as exc:
        # Mirror get_request_timeout(): surface a clear configuration error
        # instead of the bare int() message.
        raise ValueError(
            f"MAX_CONTEXT_TOKENS must be a valid integer, got: {raw!r}"
        ) from exc
    if value < 0:
        raise ValueError(f"MAX_CONTEXT_TOKENS must be >= 0, got {value}")
    return value


def check_context_length(prompt: str, processor, max_context: int) -> None:
    """Reject prompts whose tokenized length exceeds *max_context* (HTTP 400).

    A *max_context* of zero or less disables the check entirely.
    """
    if max_context <= 0:
        return
    # Some processors wrap the tokenizer as an attribute; fall back to the
    # processor itself when no such attribute exists.
    tokenizer = getattr(processor, "tokenizer", processor)
    n_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    if n_tokens <= max_context:
        return
    raise HTTPException(
        status_code=400,
        detail=f"Prompt length ({n_tokens} tokens) exceeds maximum context "
        f"window ({max_context} tokens). Reduce your prompt or increase --max-context-tokens.",
    )


def get_quantized_kv_bits(model: str):
kv_bits = float(os.environ.get("KV_BITS", 0))
if kv_bits == 0:
Expand Down Expand Up @@ -382,6 +420,10 @@ class OpenAIRequest(GenerationParams, TemplateParams):
stream: bool = Field(
False, description="Whether to stream the response chunk by chunk."
)
response_format: Optional[dict] = Field(
None,
description='Output format: {"type": "text"} or {"type": "json_object"}.',
)

def generation_kwargs(self) -> dict[str, Any]:
kwargs = self.dump_kwargs("max_output_tokens")
Expand Down Expand Up @@ -602,6 +644,10 @@ class UsageStats(OpenAIUsage):

class ChatRequest(GenerationRequest):
    """Request body for the /chat/completions endpoints."""

    # Conversation history (system/user/assistant turns); each element is
    # validated as a ChatMessage by pydantic.
    messages: List[ChatMessage]
    # Optional OpenAI-style output selector; only {"type": "text"} and
    # {"type": "json_object"} are accepted downstream.
    response_format: Optional[dict] = Field(
        None,
        description='Output format: {"type": "text"} or {"type": "json_object"}.',
    )


class ChatChoice(BaseModel):
Expand Down Expand Up @@ -630,6 +676,29 @@ class ChatStreamChunk(BaseModel):
usage: Optional[UsageStats]


def resolve_response_format(messages, response_format):
    """Prepend a JSON-only system instruction when json_object output is requested.

    Returns a new list — the original messages list is not mutated.
    """
    if not response_format:
        return messages

    requested = response_format.get("type", "text")
    if requested == "text":
        return messages
    if requested == "json_object":
        instruction = {
            "role": "system",
            "content": (
                "You must respond with valid JSON only. "
                "Do not include any text outside the JSON object."
            ),
        }
        return [instruction] + messages

    # Anything other than "text"/"json_object" is a client error.
    raise HTTPException(
        status_code=400,
        detail=f"Unsupported response_format type: '{requested}'. "
        "Supported types are 'text' and 'json_object'.",
    )


def build_generation_kwargs(
request: Any,
template_kwargs: dict[str, Any],
Expand Down Expand Up @@ -837,6 +906,10 @@ def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, mode
print("no input")
raise HTTPException(status_code=400, detail="Missing input.")

chat_messages = resolve_response_format(
chat_messages, openai_request.response_format
)

template_kwargs = openai_request.template_kwargs()
formatted_prompt = apply_chat_template(
processor,
Expand All @@ -845,6 +918,7 @@ def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, mode
num_images=len(images),
**template_kwargs,
)
check_context_length(formatted_prompt, processor, get_max_context_tokens())
generation_kwargs = build_generation_kwargs(openai_request, template_kwargs)

generated_at = datetime.now().timestamp()
Expand Down Expand Up @@ -982,15 +1056,35 @@ async def stream_generator():
else:
# Non-streaming response
try:
# Use generate from generate.py
result = generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
verbose=False, # stats are passed in the response
**generation_kwargs,
)
# Use generate from generate.py, with request timeout.
# NOTE: wait_for cancels the future but cannot interrupt the
# sync generate() running in the thread pool. The thread will
# run to completion; only the await is aborted.
timeout = get_request_timeout()
loop = asyncio.get_running_loop()
try:
result = await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
verbose=False,
**generation_kwargs,
),
),
timeout=timeout,
)
except asyncio.TimeoutError:
print(f"[cancellation] /responses generation timed out after {timeout}s.")
mx.clear_cache()
gc.collect()
raise HTTPException(
status_code=504,
detail=f"Generation timed out after {timeout} seconds.",
)
# Clean up resources
mx.clear_cache()
gc.collect()
Expand Down Expand Up @@ -1051,7 +1145,7 @@ async def stream_generator():
"/chat/completions", response_model=None
) # Response model handled dynamically based on stream flag
@app.post("/v1/chat/completions", response_model=None, include_in_schema=False)
async def chat_completions_endpoint(request: ChatRequest):
async def chat_completions_endpoint(request: ChatRequest, raw_request: Request):
"""
Generate text based on a prompt and optional images.
Prompt must be a list of chat messages, including system, user, and assistant messages.
Expand Down Expand Up @@ -1103,6 +1197,9 @@ async def chat_completions_endpoint(request: ChatRequest):
tool_parser_type = _infer_tool_parser(tokenizer.chat_template)
if tool_parser_type is not None:
tool_module = load_tool_module(tool_parser_type)
processed_messages = resolve_response_format(
processed_messages, request.response_format
)
template_kwargs = request.template_kwargs()
formatted_prompt = apply_chat_template(
processor,
Expand All @@ -1113,6 +1210,7 @@ async def chat_completions_endpoint(request: ChatRequest):
tools=tools,
**template_kwargs,
)
check_context_length(formatted_prompt, processor, get_max_context_tokens())
generation_kwargs = build_generation_kwargs(request, template_kwargs)

if request.stream:
Expand All @@ -1134,6 +1232,13 @@ async def stream_generator():
output_text = ""
request_id = f"chatcmpl-{uuid.uuid4()}"
for chunk in token_iterator:
# Check if client disconnected
if await raw_request.is_disconnected():
print("[cancellation] Client disconnected during /chat/completions streaming, aborting generation.")
if token_iterator is not None:
token_iterator.close()
return

if chunk is None or not hasattr(chunk, "text"):
print("Warning: Received unexpected chunk format:", chunk)
continue
Expand Down Expand Up @@ -1199,6 +1304,12 @@ async def stream_generator():

yield "data: [DONE]\n\n"

except asyncio.CancelledError:
print("[cancellation] /chat/completions stream cancelled (client disconnect).")
if token_iterator is not None:
token_iterator.close()
raise

except Exception as e:
print(f"Error during stream generation: {e}")
traceback.print_exc()
Expand All @@ -1223,17 +1334,37 @@ async def stream_generator():
else:
# Non-streaming response
try:
# Use generate from generate.py
gen_result = generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
audio=audio,
verbose=False, # Keep API output clean
vision_cache=model_cache.get("vision_cache"),
**generation_kwargs,
)
# Use generate from generate.py, with request timeout.
# NOTE: wait_for cancels the future but cannot interrupt the
# sync generate() running in the thread pool. The thread will
# run to completion; only the await is aborted.
timeout = get_request_timeout()
loop = asyncio.get_running_loop()
try:
gen_result = await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: generate(
model=model,
processor=processor,
prompt=formatted_prompt,
image=images,
audio=audio,
verbose=False,
vision_cache=model_cache.get("vision_cache"),
**generation_kwargs,
),
),
timeout=timeout,
)
except asyncio.TimeoutError:
print(f"[cancellation] /chat/completions generation timed out after {timeout}s.")
mx.clear_cache()
gc.collect()
raise HTTPException(
status_code=504,
detail=f"Generation timed out after {timeout} seconds.",
)
# Clean up resources
mx.clear_cache()
gc.collect()
Expand Down Expand Up @@ -1433,6 +1564,13 @@ def main():
default=DEFAULT_QUANTIZED_KV_START,
help="Start index (of token) for the quantized KV cache.",
)
parser.add_argument(
"--max-context-tokens",
type=int,
default=0,
help="Maximum context window in tokens. Requests exceeding this are rejected. "
"0 means no limit. (default: %(default)s)",
)
parser.add_argument(
"--reload",
action="store_true",
Expand All @@ -1454,6 +1592,7 @@ def main():
os.environ["KV_QUANT_SCHEME"] = args.kv_quant_scheme
os.environ["MAX_KV_SIZE"] = str(args.max_kv_size)
os.environ["QUANTIZED_KV_START"] = str(args.quantized_kv_start)
os.environ["MAX_CONTEXT_TOKENS"] = str(args.max_context_tokens)

uvicorn.run(
"mlx_vlm.server:app",
Expand Down
Loading