-
Notifications
You must be signed in to change notification settings - Fork 252
Fix: Improve HTTP API structure and async handler usage (#569) #2063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| """Health, readiness, and device stats HTTP routes.""" | ||
|
|
||
| from fastapi import APIRouter, Depends | ||
| from typing import Any, Optional | ||
| from starlette.responses import JSONResponse | ||
|
|
||
| from inference.core.env import DOCKER_SOCKET_PATH | ||
| from inference.core.managers.metrics import get_container_stats | ||
| from inference.core.utils.container import is_docker_socket_mounted | ||
|
|
||
|
|
||
def create_health_router(model_init_state: Optional[Any] = None) -> APIRouter:
    """Build a router exposing health, readiness, and device-stats endpoints.

    Args:
        model_init_state: Optional state object exposing a ``lock`` context
            manager and an ``is_ready`` flag, consulted by the readiness
            probe. When ``None``, ``/readiness`` always reports ready.
            (Assumes ``lock`` is usable as a ``with`` target — confirm against
            the caller that constructs this state.)

    Returns:
        APIRouter: Router with ``/device/stats``, ``/readiness`` and
        ``/healthz`` registered.
    """
    router = APIRouter()

    @router.get("/device/stats", summary="Device/container statistics")
    def device_stats():
        """Return container statistics collected via the Docker socket.

        Responds with 404 when ``DOCKER_SOCKET_PATH`` is not configured at
        all, and 500 when it is configured but the socket is not actually
        mounted; both carry the same explanatory payload.
        """
        not_configured_error_message = {
            "error": "Device statistics endpoint is not enabled.",
            "hint": (
                "Mount the Docker socket and point its location when running the docker "
                "container to collect device stats "
                "(i.e. `docker run ... -v /var/run/docker.sock:/var/run/docker.sock "
                "-e DOCKER_SOCKET_PATH=/var/run/docker.sock ...`)."
            ),
        }
        if not DOCKER_SOCKET_PATH:
            # Feature deliberately disabled: treat the route as absent (404).
            return JSONResponse(
                status_code=404,
                content=not_configured_error_message,
            )
        if not is_docker_socket_mounted(docker_socket_path=DOCKER_SOCKET_PATH):
            # Configured but misconfigured deployment: surface as server error.
            return JSONResponse(
                status_code=500,
                content=not_configured_error_message,
            )

        container_stats = get_container_stats(docker_socket_path=DOCKER_SOCKET_PATH)
        return JSONResponse(status_code=200, content=container_stats)

    @router.get("/readiness", status_code=200)
    def readiness(state: Any = Depends(lambda: model_init_state)):
        """Readiness endpoint for Kubernetes readiness probe.

        Returns 200 with ``{"status": "ready"}`` once model initialization has
        completed (or when no init state was provided), otherwise 503.
        """
        if state is None:
            return {"status": "ready"}
        with state.lock:
            if state.is_ready:
                return {"status": "ready"}
            return JSONResponse(
                content={"status": "not ready"}, status_code=503
            )

    @router.get("/healthz", status_code=200)
    def healthz():
        """Health endpoint for Kubernetes liveness probe."""
        return {"status": "healthy"}

    return router
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,303 @@ | ||
| """Roboflow trained-model inference HTTP routes (/infer/*).""" | ||
|
|
||
| from typing import List, Optional, Union | ||
|
|
||
| from fastapi import APIRouter, BackgroundTasks, Query, Request, HTTPException | ||
|
|
||
| from inference.core import logger | ||
| from inference.core.entities.requests.inference import ( | ||
| ClassificationInferenceRequest, | ||
| DepthEstimationRequest, | ||
| InferenceRequest, | ||
| InstanceSegmentationInferenceRequest, | ||
| KeypointsDetectionInferenceRequest, | ||
| ObjectDetectionInferenceRequest, | ||
| LMMInferenceRequest, | ||
| SemanticSegmentationInferenceRequest, | ||
| ) | ||
| from inference.core.entities.responses.inference import ( | ||
| ClassificationInferenceResponse, | ||
| DepthEstimationResponse, | ||
| InferenceResponse, | ||
| InstanceSegmentationInferenceResponse, | ||
| KeypointsDetectionInferenceResponse, | ||
| ObjectDetectionInferenceResponse, | ||
| MultiLabelClassificationInferenceResponse, | ||
| StubResponse, | ||
| LMMInferenceResponse, | ||
| SemanticSegmentationInferenceResponse, | ||
| ) | ||
| from inference.core.env import DEPTH_ESTIMATION_ENABLED, LMM_ENABLED, MOONDREAM2_ENABLED | ||
| from inference.core.interfaces.http.error_handlers import with_route_exceptions | ||
| from inference.core.interfaces.http.orjson_utils import orjson_response | ||
| from inference.core.managers.base import ModelManager | ||
| from inference.models.aliases import resolve_roboflow_model_alias | ||
| from inference.usage_tracking.collector import usage_collector | ||
|
|
||
|
|
||
def create_inference_router(
    model_manager: ModelManager,
) -> APIRouter:
    """Build a router exposing the Roboflow trained-model ``/infer/*`` routes.

    Args:
        model_manager: Manager used to register models on demand and execute
            synchronous inference requests.

    Returns:
        APIRouter: Router with object detection, instance/semantic
        segmentation, classification, and keypoints routes, plus LMM and
        depth-estimation routes when the corresponding env flags are enabled.
    """
    router = APIRouter()

    def process_inference_request(
        inference_request: InferenceRequest,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
        **kwargs,
    ) -> InferenceResponse:
        """Resolve the model alias, ensure the model is loaded, and infer.

        Args:
            inference_request: Request payload for any supported task type.
            countinference: Optional flag forwarded to the model manager
                controlling inference counting.
            service_secret: Optional service secret forwarded to the model
                manager.
            **kwargs: Extra keyword arguments passed through to
                ``infer_from_request_sync`` (e.g. ``background_tasks``,
                ``active_learning_eligible``).

        Returns:
            InferenceResponse: Inference result wrapped in an orjson response.
        """
        de_aliased_model_id = resolve_roboflow_model_alias(
            model_id=inference_request.model_id
        )
        model_manager.add_model(
            de_aliased_model_id,
            inference_request.api_key,
            countinference=countinference,
            service_secret=service_secret,
        )
        resp = model_manager.infer_from_request_sync(
            de_aliased_model_id,
            inference_request,
            **kwargs,
        )
        return orjson_response(resp)

    @router.post(
        "/infer/object_detection",
        response_model=Union[
            ObjectDetectionInferenceResponse,
            List[ObjectDetectionInferenceResponse],
            StubResponse,
        ],
        summary="Object detection infer",
        description="Run inference with the specified object detection model",
        response_model_exclude_none=True,
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_object_detection(
        inference_request: ObjectDetectionInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Run inference with the specified object detection model.

        Args:
            inference_request (ObjectDetectionInferenceRequest): The request
                containing the necessary details for object detection.
            background_tasks (BackgroundTasks): Used for optional active
                learning registration after the response is sent.

        Returns:
            Union[ObjectDetectionInferenceResponse, List[ObjectDetectionInferenceResponse]]:
                The response containing the inference results.
        """
        logger.debug("Reached /infer/object_detection")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/instance_segmentation",
        response_model=Union[InstanceSegmentationInferenceResponse, StubResponse],
        summary="Instance segmentation infer",
        description="Run inference with the specified instance segmentation model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_instance_segmentation(
        inference_request: InstanceSegmentationInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Run inference with the specified instance segmentation model.

        Args:
            inference_request (InstanceSegmentationInferenceRequest): The
                request containing the necessary details for segmentation.
            background_tasks (BackgroundTasks): Used for optional active
                learning registration after the response is sent.

        Returns:
            InstanceSegmentationInferenceResponse: The response containing the
                inference results.
        """
        logger.debug("Reached /infer/instance_segmentation")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/semantic_segmentation",
        response_model=Union[SemanticSegmentationInferenceResponse, StubResponse],
        summary="Semantic segmentation infer",
        description="Run inference with the specified semantic segmentation model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_semantic_segmentation(
        inference_request: SemanticSegmentationInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Run inference with the specified semantic segmentation model.

        Args:
            inference_request (SemanticSegmentationInferenceRequest): The
                request containing the necessary details for segmentation.
            background_tasks (BackgroundTasks): Used for optional active
                learning registration after the response is sent.

        Returns:
            SemanticSegmentationInferenceResponse: The response containing the
                inference results.
        """
        logger.debug("Reached /infer/semantic_segmentation")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/classification",
        response_model=Union[
            ClassificationInferenceResponse,
            MultiLabelClassificationInferenceResponse,
            StubResponse,
        ],
        summary="Classification infer",
        description="Run inference with the specified classification model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_classification(
        inference_request: ClassificationInferenceRequest,
        background_tasks: BackgroundTasks,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Run inference with the specified classification model.

        Args:
            inference_request (ClassificationInferenceRequest): The request
                containing the necessary details for classification.
            background_tasks (BackgroundTasks): Used for optional active
                learning registration after the response is sent.

        Returns:
            Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]:
                The response containing the inference results.
        """
        logger.debug("Reached /infer/classification")
        return process_inference_request(
            inference_request,
            active_learning_eligible=True,
            background_tasks=background_tasks,
            countinference=countinference,
            service_secret=service_secret,
        )

    @router.post(
        "/infer/keypoints_detection",
        response_model=Union[KeypointsDetectionInferenceResponse, StubResponse],
        summary="Keypoints detection infer",
        description="Run inference with the specified keypoints detection model",
    )
    @with_route_exceptions
    @usage_collector("request")
    def infer_keypoints(
        inference_request: KeypointsDetectionInferenceRequest,
        countinference: Optional[bool] = None,
        service_secret: Optional[str] = None,
    ):
        """Run inference with the specified keypoints detection model.

        Args:
            inference_request (KeypointsDetectionInferenceRequest): The request
                containing the necessary details for keypoints detection.

        Returns:
            KeypointsDetectionInferenceResponse: The response containing the
                inference results.
        """
        logger.debug("Reached /infer/keypoints_detection")
        return process_inference_request(
            inference_request,
            countinference=countinference,
            service_secret=service_secret,
        )

    if LMM_ENABLED or MOONDREAM2_ENABLED:

        @router.post(
            "/infer/lmm",
            response_model=Union[
                LMMInferenceResponse,
                List[LMMInferenceResponse],
                StubResponse,
            ],
            summary="Large multi-modal model infer",
            description="Run inference with the specified large multi-modal model",
            response_model_exclude_none=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_lmm(
            inference_request: LMMInferenceRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified large multi-modal model.

            Args:
                inference_request (LMMInferenceRequest): The request containing the necessary details for LMM inference.

            Returns:
                Union[LMMInferenceResponse, List[LMMInferenceResponse]]: The response containing the inference results.
            """
            # No placeholder in the message, so a plain string (not an f-string).
            logger.debug("Reached /infer/lmm")
            return process_inference_request(
                inference_request,
                countinference=countinference,
                service_secret=service_secret,
            )

        @router.post(
            "/infer/lmm/{model_id:path}",
            response_model=Union[
                LMMInferenceResponse,
                List[LMMInferenceResponse],
                StubResponse,
            ],
            summary="Large multi-modal model infer with model ID in path",
            description="Run inference with the specified large multi-modal model. Model ID is specified in the URL path (can contain slashes).",
            response_model_exclude_none=True,
        )
        @with_route_exceptions
        @usage_collector("request")
        def infer_lmm_with_model_id(
            model_id: str,
            inference_request: LMMInferenceRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run inference with the specified large multi-modal model.

            The model_id can be specified in the URL path. If model_id is also provided
            in the request body, it must match the path parameter.

            Args:
                model_id (str): The model identifier from the URL path.
                inference_request (LMMInferenceRequest): The request containing the necessary details for LMM inference.

            Returns:
                Union[LMMInferenceResponse, List[LMMInferenceResponse]]: The response containing the inference results.

            Raises:
                HTTPException: If model_id in path and request body don't match.
            """
            logger.debug(f"Reached /infer/lmm/{model_id}")

            # Reject requests whose body names a different model than the path.
            if (
                inference_request.model_id is not None
                and inference_request.model_id != model_id
            ):
                raise HTTPException(
                    status_code=400,
                    detail=f"Model ID mismatch: path specifies '{model_id}' but request body specifies '{inference_request.model_id}'",
                )

            # Path parameter is authoritative; at this point the body either
            # omitted model_id or already matches, so assign unconditionally.
            inference_request.model_id = model_id

            return process_inference_request(
                inference_request,
                countinference=countinference,
                service_secret=service_secret,
            )

    if DEPTH_ESTIMATION_ENABLED:

        @router.post(
            "/infer/depth-estimation",
            response_model=DepthEstimationResponse,
            summary="Depth Estimation",
            description="Run the depth estimation model to generate a depth map.",
        )
        @with_route_exceptions
        def depth_estimation(
            inference_request: DepthEstimationRequest,
            countinference: Optional[bool] = None,
            service_secret: Optional[str] = None,
        ):
            """Run the depth estimation model to generate a depth map.

            Args:
                inference_request (DepthEstimationRequest): The request
                    containing the necessary details for depth estimation.

            Returns:
                DepthEstimationResponse: The response containing the depth map.
            """
            logger.debug("Reached /infer/depth-estimation")
            # Depth estimation bypasses alias resolution: the model id is used
            # as-is rather than going through process_inference_request.
            depth_model_id = inference_request.model_id
            model_manager.add_model(
                depth_model_id,
                inference_request.api_key,
                countinference=countinference,
                service_secret=service_secret,
            )
            response = model_manager.infer_from_request_sync(
                depth_model_id, inference_request
            )
            return response

    return router
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this modular approach we could also improve the docs by doing:
This would group the doc endpoints nicely