diff --git a/LICENSE b/LICENSE index 76bb742b00..a800576cea 100644 --- a/LICENSE +++ b/LICENSE @@ -2,6 +2,6 @@ LICENSE.core (Apache 2.0) applies to all files in this repository except for files in or under any directory that contains a superseding license file (such as the models located in `inference/models/` which are governed by their own individual licenses and -the files and folders in the `inference/enterprise/` directory which are +the files and folders in the `inference/enterprise/` and `inference_cli/lib/enterprise/` directories which are governed by the Roboflow Enterprise License located at `inference/enterprise/LICENSE.txt`). \ No newline at end of file diff --git a/inference_cli/lib/enterprise/LICENSE.txt b/inference_cli/lib/enterprise/LICENSE.txt new file mode 100644 index 0000000000..bdddaf3ad7 --- /dev/null +++ b/inference_cli/lib/enterprise/LICENSE.txt @@ -0,0 +1,39 @@ +The Roboflow Enterprise License (the “Enterprise License”) +Copyright (c) 2023 Roboflow Inc. + +With regard to the Roboflow Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have accepted +and are following the terms of a separate Roboflow Enterprise agreement +that governs how you use the software. + +Subject to the foregoing sentence, you are free to modify this Software and publish +patches to the Software. You agree that Roboflow and/or its licensors (as applicable) +retain all right, title and interest in and to all such modifications and/or patches, +and all such modifications and/or patches may only be used, copied, modified, +displayed, distributed, or otherwise exploited with a valid Roboflow Enterprise +license for the correct number of seats, devices, inferences, and other +usage metrics specified therein. + +Notwithstanding the foregoing, you may copy and modify the Software for development +and testing purposes, without requiring a subscription. 
You agree that Roboflow and/or +its licensors (as applicable) retain all right, title and interest in and to all +such modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, publish, +distribute, sublicense, and/or sell the Software. + +The full text of this Enterprise License shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Roboflow Software, those +components are licensed under the original license provided by the owner of the +applicable component. 
diff --git a/inference_cli/lib/enterprise/__init__.py b/inference_cli/lib/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/core.py b/inference_cli/lib/enterprise/core.py new file mode 100644 index 0000000000..52d90303d6 --- /dev/null +++ b/inference_cli/lib/enterprise/core.py @@ -0,0 +1,8 @@ +import typer + +from inference_cli.lib.enterprise.inference_compiler.cli.core import ( + inference_compiler_app, +) + +enterprise_app = typer.Typer(help="Roboflow Enterprise commands") +enterprise_app.add_typer(inference_compiler_app, name="inference-compiler") diff --git a/inference_cli/lib/enterprise/inference_compiler/__init__.py b/inference_cli/lib/enterprise/inference_compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/inference_compiler/adapters/__init__.py b/inference_cli/lib/enterprise/inference_compiler/adapters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/inference_compiler/adapters/models_service.py b/inference_cli/lib/enterprise/inference_compiler/adapters/models_service.py new file mode 100644 index 0000000000..84dd8d6aa7 --- /dev/null +++ b/inference_cli/lib/enterprise/inference_compiler/adapters/models_service.py @@ -0,0 +1,465 @@ +from typing import Any, Dict, List, Literal, Optional + +import backoff +import requests +from pydantic import BaseModel, Field +from requests import Timeout + +from inference_cli.lib.enterprise.inference_compiler.constants import ( + ROBOFLOW_API_HOST, + ROBOFLOW_API_KEY, +) +from inference_cli.lib.enterprise.inference_compiler.errors import ( + RetryError, + RuntimeConfigurationError, +) +from inference_cli.lib.enterprise.inference_compiler.utils.http import ( + handle_response_errors, +) + + +class SignedURLDetails(BaseModel): + type: Literal["signed-url-details-v1"] + upload_url: str = Field(alias="uploadUrl") + method: str = Field(alias="method") 
+ extension_headers: dict = Field(alias="extensionHeaders") + max_file_size: int = Field(alias="maxFileSize") + + class Config: + populate_by_name = True + + +class ExternalFileUploadSpecs(BaseModel): + type: Literal["external-file-upload-specs-v1"] + file_handle: str = Field(alias="fileHandle") + signed_url_details: SignedURLDetails = Field(alias="signedUrlDetails") + + class Config: + populate_by_name = True + + +class ModelPackageRegistrationResponse(BaseModel): + model_id: str = Field(alias="modelId") + model_package_id: str = Field(alias="modelPackageId") + file_upload_specs: List[ExternalFileUploadSpecs] = Field(alias="filesUploadSpecs") + + class Config: + populate_by_name = True + + +class FileConfirmation(BaseModel): + file_handle: str = Field(alias="fileHandle") + md5_hash: Optional[str] = Field(alias="md5Hash", default=None) + + class Config: + populate_by_name = True + + +class ExternalPublicTRTTimingCompilationEntryV1(BaseModel): + type: Literal["external-public-trt-timing-cache-entry-v1"] + cache_key: str = Field(alias="cacheKey") + compilation_features: Dict[str, Any] = Field(alias="compilationFeatures") + file_handle: str = Field(alias="fileHandle") + download_url: str = Field(alias="downloadUrl") + md5_hash: Optional[str] = Field(alias="md5Hash", default=None) + + class Config: + populate_by_name = True + + +class ExternalPrivateTRTTimingCompilationEntryV1(BaseModel): + type: Literal["external-private-trt-timing-cache-entry-v1"] + cache_key: str = Field(alias="cacheKey") + compilation_features: Dict[str, Any] = Field(alias="compilationFeatures") + file_handle: str = Field(alias="fileHandle") + download_url: str = Field(alias="downloadUrl") + md5_hash: Optional[str] = Field(alias="md5Hash", default=None) + + class Config: + populate_by_name = True + + +class PrivateTRTTimingCacheEntryRegistrationResults(BaseModel): + cache_key: str = Field(alias="cacheKey") + upload_specs: ExternalFileUploadSpecs = Field(alias="uploadSpecs") + + class Config: + 
populate_by_name = True + + +class ExternalPrivateTRTTimingCacheListEntryV1(BaseModel): + type: Literal["external-private-trt-timing-cache-list-entry-v1"] + cache_key: str = Field(alias="cacheKey") + compilation_features: Dict[str, Any] = Field(alias="compilationFeatures") + sealed: bool + + class Config: + populate_by_name = True + + +class PrivateTRTTimingCacheEntriesList(BaseModel): + cache_entries: List[ExternalPrivateTRTTimingCacheListEntryV1] = Field( + alias="cacheEntries" + ) + next_page_token: Optional[str] = Field(alias="nextPageToken", default=None) + + class Config: + populate_by_name = True + + +class ModelsServiceClient: + @classmethod + def init( + cls, + api_key: Optional[str] = None, + ) -> "ModelsServiceClient": + if api_key is None: + api_key = ROBOFLOW_API_KEY + if api_key is None: + raise RuntimeConfigurationError( + "Could not initialise Models Service client without Roboflow API key. " + "Set the key explicitly or use environment variable `ROBOFLOW_API_KEY`. If you need help getting " + "your Roboflow API key, " + "visit: https://docs.roboflow.com/developer/authentication/find-your-roboflow-api-key" + ) + return cls( + api_host=ROBOFLOW_API_HOST, + api_key=api_key, + ) + + def __init__( + self, + api_host: str, + api_key: str, + ): + self._api_host = api_host + self._api_key = api_key + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def register_model_package( + self, + model_id: str, + package_manifest: dict, + file_handles: List[str], + model_features: Optional[dict] = None, + ): + try: + payload = { + "modelId": model_id, + "packageManifest": package_manifest, + "fileHandles": file_handles, + } + if model_features: + payload["modelFeatures"] = model_features + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/register", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise 
RetryError(f"Connectivity error") + handle_response_errors(response=response) + return ModelPackageRegistrationResponse.model_validate(response.json()) + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def confirm_model_package_artefacts( + self, + model_id: str, + model_package_id: str, + confirmations: List[FileConfirmation], + seal_model_package: Optional[bool] = None, + ) -> None: + try: + payload: Dict[str, Any] = { + "modelId": model_id, + "modelPackageId": model_package_id, + "confirmations": [c.model_dump(by_alias=True) for c in confirmations], + } + if seal_model_package: + payload["sealModelPackage"] = seal_model_package + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/artefacts/confirm", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def add_model_package_artefacts( + self, + model_id: str, + model_package_id: str, + file_handles: List[str], + ) -> ModelPackageRegistrationResponse: + try: + payload = { + "modelId": model_id, + "modelPackageId": model_package_id, + "fileHandles": file_handles, + } + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/artefacts/add", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return ModelPackageRegistrationResponse.model_validate(response.json()) # type: ignore + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def remove_model_package_artefacts( + self, + model_id: str, + model_package_id: str, + 
file_handles: List[str], + ) -> None: + try: + payload = { + "modelId": model_id, + "modelPackageId": model_package_id, + "fileHandles": file_handles, + } + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/artefacts/remove", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def seal_model_package(self, model_id: str, package_id: str) -> None: + payload: Dict[str, Any] = {"modelId": model_id, "modelPackageId": package_id} + try: + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/seal", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def un_seal_model_package(self, model_id: str, model_package_id: str) -> None: + payload: Dict[str, Any] = { + "modelId": model_id, + "modelPackageId": model_package_id, + } + try: + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/un-seal", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def delete_model_package(self, model_id: str, model_package_id: str) -> None: + payload: Dict[str, Any] = { + "modelId": model_id, + "modelPackageId": model_package_id, + } + try: + response = requests.post( + 
f"{self._api_host}/models/v1/external/model-packages/delete", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def un_delete_model_package(self, model_id: str, model_package_id: str) -> None: + payload: Dict[str, Any] = { + "modelId": model_id, + "modelPackageId": model_package_id, + } + try: + response = requests.post( + f"{self._api_host}/models/v1/external/model-packages/un-delete", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def get_public_trt_timing_cache( + self, compilation_features: Dict[str, Any] + ) -> ExternalPublicTRTTimingCompilationEntryV1: + try: + response = requests.post( + f"{self._api_host}/models/v1/external/trt-compilation/timing-cache/public/get", + json={"compilationFeatures": compilation_features}, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return ExternalPublicTRTTimingCompilationEntryV1.model_validate( # type: ignore + response.json()["cacheEntry"] + ) + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def get_private_trt_timing_cache( + self, compilation_features: Dict[str, Any] + ) -> ExternalPrivateTRTTimingCompilationEntryV1: + try: + response = requests.post( + f"{self._api_host}/models/v1/external/trt-compilation/timing-cache/private/get", + 
json={"compilationFeatures": compilation_features}, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return ExternalPrivateTRTTimingCompilationEntryV1.model_validate( # type: ignore + response.json()["cacheEntry"] + ) + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def register_private_trt_timing_cache( + self, + compilation_features: Dict[str, Any], + ) -> PrivateTRTTimingCacheEntryRegistrationResults: + try: + response = requests.post( + f"{self._api_host}/models/v1/external/trt-compilation/timing-cache/private/register", + json={"compilationFeatures": compilation_features}, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return PrivateTRTTimingCacheEntryRegistrationResults.model_validate( # type: ignore + response.json() + ) + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def confirm_private_trt_timing_cache_upload( + self, cache_key: str, confirmation: FileConfirmation + ) -> None: + try: + payload: Dict[str, Any] = { + "cacheKey": cache_key, + "confirmation": confirmation.model_dump(by_alias=True), + } + response = requests.post( + f"{self._api_host}/models/v1/external/trt-compilation/timing-cache/private/confirm", + json=payload, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return None + + @backoff.on_exception( + backoff.fibo, + exception=RetryError, + max_tries=3, + max_value=5, + ) + def list_private_timing_cache_entries( + self, + page_size: Optional[int] = None, + start_after: Optional[str] 
= None, + ) -> PrivateTRTTimingCacheEntriesList: + try: + query: Dict[str, Any] = {} + if page_size is not None: + query["pageSize"] = page_size + if start_after is not None: + query["startAfter"] = start_after + response = requests.get( + f"{self._api_host}/models/v1/external/trt-compilation/timing-cache/private/list", + params=query, + headers=self._add_auth_headers(), + ) + except (ConnectionError, Timeout, requests.exceptions.ConnectionError): + raise RetryError(f"Connectivity error") + handle_response_errors(response=response) + return PrivateTRTTimingCacheEntriesList.model_validate(response.json()) # type: ignore + + def _add_auth_headers( + self, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: + if headers is None: + headers = {} + headers["Authorization"] = f"Bearer {self._api_key}" + return headers diff --git a/inference_cli/lib/enterprise/inference_compiler/cli/__init__.py b/inference_cli/lib/enterprise/inference_compiler/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/inference_compiler/cli/core.py b/inference_cli/lib/enterprise/inference_compiler/cli/core.py new file mode 100644 index 0000000000..414aec5a61 --- /dev/null +++ b/inference_cli/lib/enterprise/inference_compiler/cli/core.py @@ -0,0 +1,257 @@ +from enum import Enum +from typing import Annotated, Optional + +import typer +from rich.console import Console + +from inference_cli.lib.container_adapter import get_image, pull_image +from inference_cli.lib.env import ROBOFLOW_API_KEY + +inference_compiler_app = typer.Typer(name="Inference compiler") + + +class CompilationMode(str, Enum): + AUTO = "auto" + CONTAINER = "container" + PYTHON = "python" + + +@inference_compiler_app.callback() +def compiler_callback(): + pass + + +@inference_compiler_app.command(name="compile-model") +def compile_model( + model_id: Annotated[ + str, + typer.Option( + "--model-id", + "-m", + help="Model ID in format project/version.", + ), + ], + 
api_key: Annotated[ + Optional[str], + typer.Option( + "--api-key", + "-a", + help="Roboflow API key for your workspace. If not given - env variable `ROBOFLOW_API_KEY` will be used", + ), + ] = None, + debug_mode: Annotated[ + bool, + typer.Option( + "--debug-mode/--no-debug-mode", + help="Flag enabling errors stack traces to be displayed (helpful for debugging)", + ), + ] = False, + trt_forward_compatible: Annotated[ + bool, + typer.Option( + "--trt-forward-compatible/--no-trt-forward-compatible", + help="Flag to decide if forward-compatibility mode in TRT compilation should be enabled", + ), + ] = False, + trt_same_cc_compatible: Annotated[ + bool, + typer.Option( + "--trt-same-cc-compatible/--no-trt-same-cc-compatible", + help="Flag to decide if engine should be compiled to be compatible with devices sharing the same CUDA CC " + "to the one running compilation procedure", + ), + ] = False, + compilation_mode: Annotated[ + CompilationMode, + typer.Option( + "--compilation-mode", + help="Selection of compilation mode - `container` runs the procedure inside `inference` server, " + "`python` runs in-process. 
`auto` (default) inspect environment dependencies to verify if "
+            "the procedure can be run in-process, if not - offloading to the server.",
+        ),
+    ] = CompilationMode.AUTO,
+    image: Annotated[
+        Optional[str],
+        typer.Option(
+            "--image",
+            help="Point specific docker image you would like to run with command (useful for development of custom "
+            "builds of inference server)",
+        ),
+    ] = None,
+    use_local_images: Annotated[
+        bool,
+        typer.Option(
+            "--use-local-images/--not-use-local-images",
+            help="Flag to allow using local images (if set False image is always attempted to be pulled)",
+        ),
+    ] = False,
+) -> None:
+    console = Console()
+    console.print(
+        "You are running component licensed under Roboflow Enterprise License - please acknowledge the "
+        "terms of use: https://github.com/roboflow/inference/blob/main/inference_cli/lib/enterprise/LICENSE.txt",
+    )
+    if api_key is None:
+        api_key = ROBOFLOW_API_KEY
+    try:
+        if len(model_id.split(" ")) > 1:
+            raise ValueError(
+                "Format of model_id is incorrect - expected string without whitespaces"
+            )
+        if compilation_to_run_in_container(
+            compilation_mode=compilation_mode, console=console
+        ):
+            run_compilation_in_container(
+                model_id=model_id,
+                api_key=api_key,
+                trt_forward_compatible=trt_forward_compatible,
+                trt_same_cc_compatible=trt_same_cc_compatible,
+                console=console,
+                image=image,
+                use_local_images=use_local_images,
+            )
+        else:
+            run_compilation_in_python(
+                model_id=model_id,
+                api_key=api_key,
+                trt_forward_compatible=trt_forward_compatible,
+                trt_same_cc_compatible=trt_same_cc_compatible,
+                console=console,
+            )
+
+    except KeyboardInterrupt:
+        print("Command interrupted.")
+        return
+    except Exception as error:
+        if debug_mode:
+            raise error
+        typer.echo(f"Command failed. 
Cause: {error}") + raise typer.Exit(code=1) + + +def compilation_to_run_in_container( + compilation_mode: CompilationMode, + console: Console, +) -> bool: + if compilation_mode == CompilationMode.CONTAINER: + return True + if compilation_mode == CompilationMode.PYTHON: + return False + try: + import inference_models + except Exception as error: + console.print( + "Inference compiler running in `auto` mode could not import `inference-models`, which is required " + f"to compile package in process - offloading to container. Error: {error}", + ) + return True + try: + import tensorrt + except Exception as error: + console.print( + "Inference compiler running in `auto` mode could not import `tensorrt`, which is required " + f"to compile package in process - offloading to container. Error: {error}", + ) + return True + return False + + +def run_compilation_in_container( + model_id: str, + api_key: Optional[str] = None, + trt_forward_compatible: bool = False, + trt_same_cc_compatible: bool = False, + console: Optional[Console] = None, + image: Optional[str] = None, + use_local_images: bool = False, +) -> None: + import docker + + if image is None: + image = get_image() + if "-cpu" in image: + raise ValueError( + "Attempted to run compilation using `inference-server` CPU image, which does not support TRT compilation. " + "This error may be result of pointing invalid docker image with `--image` parameter or image " + "auto-selection choice, due to lack of GPU detected." 
+    )
+    is_gpu = "gpu" in image and "jetson" not in image
+    is_jetson = "jetson" in image
+    device_requests = None
+    privileged = False
+    docker_run_kwargs = {}
+    if is_gpu:
+        device_requests = [
+            docker.types.DeviceRequest(device_ids=["all"], capabilities=[["gpu"]])
+        ]
+    if is_jetson:
+        privileged = True
+        docker_run_kwargs = {"runtime": "nvidia"}
+    pull_image(image, use_local_images=use_local_images)
+    console.print("Starting model compilation inside docker container")
+    command = build_container_command(
+        model_id=model_id,
+        api_key=api_key,
+        trt_forward_compatible=trt_forward_compatible,
+        trt_same_cc_compatible=trt_same_cc_compatible,
+    )
+    docker_client = docker.from_env()
+    docker_client.containers.run(
+        image=image,
+        command=command.split(" "),
+        privileged=privileged,
+        detach=True,
+        device_requests=device_requests,
+        security_opt=["no-new-privileges"] if not is_jetson else None,
+        cap_drop=["ALL"] if not is_jetson else None,
+        cap_add=(
+            (["NET_BIND_SERVICE"] + (["SYS_ADMIN"] if is_gpu else []))
+            if not is_jetson
+            else None
+        ),
+        read_only=not is_jetson,
+        volumes={"/tmp": {"bind": "/tmp", "mode": "rw"}},
+        network_mode="bridge",
+        ipc_mode="private" if not is_jetson else None,
+        **docker_run_kwargs,
+    )
+
+
+def build_container_command(
+    model_id: str,
+    api_key: Optional[str] = None,
+    trt_forward_compatible: bool = False,
+    trt_same_cc_compatible: bool = False,
+) -> str:
+    command = (
+        f"inference enterprise inference-compiler compile-model --model-id {model_id}"
+    )
+    if api_key:
+        command += f" --api-key {api_key}"
+    command += " --trt-forward-compatible" if trt_forward_compatible else " --no-trt-forward-compatible"
+    command += " --trt-same-cc-compatible" if trt_same_cc_compatible else " --no-trt-same-cc-compatible"
+    return command
+
+
+def stringify_boolean(value: bool) -> str:
+    if value:
+        return "true"
+    return "false"
+
+
+def run_compilation_in_python(
+    model_id: str,
+    api_key: Optional[str] = None,
+    trt_forward_compatible: bool = False,
+    trt_same_cc_compatible: bool = 
False, + console: Optional[Console] = None, +) -> None: + from inference_cli.lib.enterprise.inference_compiler.core import compiler + + compiler.compile_model( + model_id=model_id, + api_key=api_key, + trt_forward_compatible=trt_forward_compatible, + trt_same_cc_compatible=trt_same_cc_compatible, + console=console, + ) diff --git a/inference_cli/lib/enterprise/inference_compiler/constants.py b/inference_cli/lib/enterprise/inference_compiler/constants.py new file mode 100644 index 0000000000..fef7a89479 --- /dev/null +++ b/inference_cli/lib/enterprise/inference_compiler/constants.py @@ -0,0 +1,74 @@ +import os + +REQUEST_TIMEOUT_ENV = "REQUEST_TIMEOUT" +DEFAULT_REQUEST_TIMEOUT = "60" +PROD_ENVIRONMENT_NAME = "prod" +ROBOFLOW_ENVIRONMENT = os.getenv("ROBOFLOW_ENVIRONMENT", PROD_ENVIRONMENT_NAME) +ROBOFLOW_API_HOST = os.getenv( + "ROBOFLOW_API_HOST", + ( + "https://api.roboflow.com" + if ROBOFLOW_ENVIRONMENT == PROD_ENVIRONMENT_NAME + else "https://api.roboflow.one" + ), +) +ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", None) +HTTP_CODES_TO_RETRY = {408, 429, 500, 502, 503, 504} +YOLO_MODELS_MIN_DYNAMIC_BATCH_SIZE = int( + os.getenv("YOLO_MODELS_MIN_DYNAMIC_BATCH_SIZE", "1") +) +YOLO_MODELS_OPT_DYNAMIC_BATCH_SIZE = int( + os.getenv("YOLO_MODELS_OPT_DYNAMIC_BATCH_SIZE", "8") +) +YOLO_MODELS_MAX_DYNAMIC_BATCH_SIZE = int( + os.getenv("YOLO_MODELS_MAX_DYNAMIC_BATCH_SIZE", "16") +) +YOLO_MODELS_WORKSPACE_SIZE = int(os.getenv("YOLO_MODELS_WORKSPACE_SIZE", "12")) +RFDETR_MODELS_MIN_DYNAMIC_BATCH_SIZE = int( + os.getenv("RFDETR_MODELS_MIN_DYNAMIC_BATCH_SIZE", "1") +) +RFDETR_MODELS_OPT_DYNAMIC_BATCH_SIZE = int( + os.getenv("RFDETR_MODELS_OPT_DYNAMIC_BATCH_SIZE", "8") +) +RFDETR_MODELS_MAX_DYNAMIC_BATCH_SIZE = int( + os.getenv("RFDETR_MODELS_MAX_DYNAMIC_BATCH_SIZE", "16") +) +RFDETR_MODELS_WORKSPACE_SIZE = int(os.getenv("RFDETR_MODELS_WORKSPACE_SIZE", "12")) +RESNET_MODELS_MIN_DYNAMIC_BATCH_SIZE = int( + os.getenv("RESNET_MODELS_MIN_DYNAMIC_BATCH_SIZE", "1") +) 
+RESNET_MODELS_OPT_DYNAMIC_BATCH_SIZE = int( + os.getenv("RESNET_MODELS_OPT_DYNAMIC_BATCH_SIZE", "8") +) +RESNET_MODELS_MAX_DYNAMIC_BATCH_SIZE = int( + os.getenv("RESNET_MODELS_MAX_DYNAMIC_BATCH_SIZE", "128") +) +RESNET_MODELS_WORKSPACE_SIZE = int(os.getenv("RESNET_MODELS_WORKSPACE_SIZE", "12")) +VIT_MODELS_MIN_DYNAMIC_BATCH_SIZE = int( + os.getenv("VIT_MODELS_MIN_DYNAMIC_BATCH_SIZE", "1") +) +VIT_MODELS_OPT_DYNAMIC_BATCH_SIZE = int( + os.getenv("VIT_MODELS_OPT_DYNAMIC_BATCH_SIZE", "8") +) +VIT_MODELS_MAX_DYNAMIC_BATCH_SIZE = int( + os.getenv("VIT_MODELS_MAX_DYNAMIC_BATCH_SIZE", "32") +) +VIT_MODELS_WORKSPACE_SIZE = int(os.getenv("VIT_MODELS_WORKSPACE_SIZE", "12")) +DEEP_LAB_MODELS_MIN_DYNAMIC_BATCH_SIZE = int( + os.getenv("DEEP_LAB_MODELS_MIN_DYNAMIC_BATCH_SIZE", "1") +) +DEEP_LAB_MODELS_OPT_DYNAMIC_BATCH_SIZE = int( + os.getenv("DEEP_LAB_MODELS_OPT_DYNAMIC_BATCH_SIZE", "8") +) +DEEP_LAB_MODELS_MAX_DYNAMIC_BATCH_SIZE = int( + os.getenv("DEEP_LAB_MODELS_MAX_DYNAMIC_BATCH_SIZE", "32") +) +DEEP_LAB_MODELS_WORKSPACE_SIZE = int(os.getenv("DEEP_LAB_MODELS_WORKSPACE_SIZE", "12")) +KEYPOINT_DETECTION_TASK_TYPE = "keypoint-detection" +INFERENCE_CONFIG_FILE = "inference_config.json" +CLASS_NAMES_FILE = "class_names.txt" +WEIGHTS_ONNX_FILE = "weights.onnx" +KEYPOINTS_METADATA_FILE = "keypoints_metadata.json" +TRT_CONFIG_FILE = "trt_config.json" +ENGINE_PLAN_FILE = "engine.plan" +MODEL_CONFIG_FILE = "model_config.json" diff --git a/inference_cli/lib/enterprise/inference_compiler/core/__init__.py b/inference_cli/lib/enterprise/inference_compiler/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/__init__.py b/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/default.py 
"""Compilation handler for "default" (single-ONNX-input) models.

Downloads a source ONNX model package, builds TensorRT engine variants
(fp32/fp16, static or dynamic batch size), verifies the compiled artefacts
and registers them with the models service.
"""

import logging
import os.path
import tempfile
from typing import Callable, Dict, Literal, Optional, Tuple

from rich.console import Console

from inference_cli.lib.enterprise.inference_compiler.adapters.models_service import (
    ModelPackageRegistrationResponse,
    ModelsServiceClient,
)
from inference_cli.lib.enterprise.inference_compiler.constants import (
    CLASS_NAMES_FILE,
    ENGINE_PLAN_FILE,
    INFERENCE_CONFIG_FILE,
    KEYPOINT_DETECTION_TASK_TYPE,
    KEYPOINTS_METADATA_FILE,
    MODEL_CONFIG_FILE,
    TRT_CONFIG_FILE,
    WEIGHTS_ONNX_FILE,
)
from inference_cli.lib.enterprise.inference_compiler.core.compilation_handlers.utils import (
    download_model_package,
    execute_compilation,
    get_training_input_size,
    register_model_package_artefacts,
    safe_negotiate_model_packages,
)
from inference_cli.lib.enterprise.inference_compiler.core.entities import (
    CompilationConfig,
    TRTConfig,
)
from inference_cli.lib.enterprise.inference_compiler.errors import (
    AlreadyCompiledError,
    CompiledPackageRegistrationError,
    ModelVerificationError,
)
from inference_cli.lib.enterprise.inference_compiler.utils.file_system import (
    calculate_local_file_md5,
    dump_json,
    read_json,
)
from inference_cli.lib.enterprise.inference_compiler.utils.logging import (
    print_to_console,
)
from inference_models.weights_providers.entities import ModelMetadata


def compile_and_register_default_model(
    model_metadata: ModelMetadata,
    models_service_client: ModelsServiceClient,
    compilation_directory: str,
    trt_forward_compatible: bool,
    trt_same_cc_compatible: bool,
    console: Optional[Console],
    compilation_config: CompilationConfig,
) -> None:
    """Compile a model into TRT engine variants and register each of them.

    Negotiates source packages (static- and dynamic-batch), downloads the
    preferred one, then compiles fp32 and fp16 variants. When only a
    static-batch source package exists, static-batch engines are built;
    otherwise dynamic-batch engines are built with the min/opt/max batch
    sizes taken from ``compilation_config``.

    Args:
        model_metadata: Metadata (architecture, task type, packages) of the
            model to compile.
        models_service_client: Client used to register compiled packages.
        compilation_directory: Scratch directory for downloads and outputs.
        trt_forward_compatible: Build engines runnable on newer TRT versions.
        trt_same_cc_compatible: Build engines portable across GPUs with the
            same compute capability.
        console: Optional Rich console for progress messages.
        compilation_config: Architecture-specific knobs (workspace size,
            batch sizes, optional verification callback).
    """
    (
        package_with_static_batch_size,
        package_with_dynamic_batch_size,
    ) = safe_negotiate_model_packages(
        model_metadata=model_metadata,
    )
    print_to_console(message="Downloading source model artefacts...", console=console)
    source_packages_directory = os.path.join(compilation_directory, "source_packages")
    expected_files = [INFERENCE_CONFIG_FILE, CLASS_NAMES_FILE, WEIGHTS_ONNX_FILE]
    if model_metadata.task_type == KEYPOINT_DETECTION_TASK_TYPE:
        # Keypoint models ship an extra metadata file describing the skeleton.
        expected_files.append(KEYPOINTS_METADATA_FILE)
    # Prefer the dynamic-batch source package when one is available.
    if package_with_dynamic_batch_size is not None:
        package_files = download_model_package(
            model_architecture=model_metadata.model_architecture,
            task_type=model_metadata.task_type,
            model_package=package_with_dynamic_batch_size,
            target_directory=source_packages_directory,
            expected_files=expected_files,
            verify_model=compilation_config.verify_model,
        )
    else:
        package_files = download_model_package(
            model_architecture=model_metadata.model_architecture,
            task_type=model_metadata.task_type,
            model_package=package_with_static_batch_size,
            target_directory=source_packages_directory,
            expected_files=expected_files,
            verify_model=compilation_config.verify_model,
        )
    training_size = get_training_input_size(
        inference_config_path=package_files[INFERENCE_CONFIG_FILE]
    )
    compilation_output_dir = os.path.join(compilation_directory, "compilation_output")
    os.makedirs(compilation_output_dir, exist_ok=True)
    if package_with_dynamic_batch_size is None:
        # Static-batch source: build static fp32 + fp16 engines and stop.
        static_bs_fp32_engine_directory = os.path.join(
            compilation_output_dir, "static_bs_fp32"
        )
        compile_and_register_default_model_trt_variant(
            models_service_client=models_service_client,
            model_metadata=model_metadata,
            compilation_directory=static_bs_fp32_engine_directory,
            local_files=package_files,
            training_size=training_size,
            precision="fp32",
            workspace_size_gb=compilation_config.workspace_size_gb,
            trt_forward_compatible=trt_forward_compatible,
            same_compute_compatibility=trt_same_cc_compatible,
            verify_model=compilation_config.verify_model,
            console=console,
        )
        static_bs_fp16_engine_directory = os.path.join(
            compilation_output_dir, "static_bs_fp16"
        )
        compile_and_register_default_model_trt_variant(
            models_service_client=models_service_client,
            model_metadata=model_metadata,
            compilation_directory=static_bs_fp16_engine_directory,
            local_files=package_files,
            training_size=training_size,
            precision="fp16",
            workspace_size_gb=compilation_config.workspace_size_gb,
            trt_forward_compatible=trt_forward_compatible,
            same_compute_compatibility=trt_same_cc_compatible,
            verify_model=compilation_config.verify_model,
            console=console,
        )
        return None
    # Dynamic-batch source: build dynamic fp32 + fp16 engines.
    dynamic_bs_fp32_engine_directory = os.path.join(
        compilation_output_dir, "dynamic_bs_fp32"
    )
    compile_and_register_default_model_trt_variant(
        models_service_client=models_service_client,
        model_metadata=model_metadata,
        compilation_directory=dynamic_bs_fp32_engine_directory,
        local_files=package_files,
        training_size=training_size,
        precision="fp32",
        workspace_size_gb=compilation_config.workspace_size_gb,
        min_batch_size=compilation_config.min_batch_size,
        opt_batch_size=compilation_config.opt_batch_size,
        max_batch_size=compilation_config.max_batch_size,
        trt_forward_compatible=trt_forward_compatible,
        same_compute_compatibility=trt_same_cc_compatible,
        verify_model=compilation_config.verify_model,
        console=console,
    )
    dynamic_bs_fp16_engine_directory = os.path.join(
        compilation_output_dir, "dynamic_bs_fp16"
    )
    compile_and_register_default_model_trt_variant(
        models_service_client=models_service_client,
        model_metadata=model_metadata,
        compilation_directory=dynamic_bs_fp16_engine_directory,
        local_files=package_files,
        training_size=training_size,
        precision="fp16",
        workspace_size_gb=compilation_config.workspace_size_gb,
        min_batch_size=compilation_config.min_batch_size,
        opt_batch_size=compilation_config.opt_batch_size,
        max_batch_size=compilation_config.max_batch_size,
        trt_forward_compatible=trt_forward_compatible,
        same_compute_compatibility=trt_same_cc_compatible,
        verify_model=compilation_config.verify_model,
        console=console,
    )


def compile_and_register_default_model_trt_variant(
    models_service_client: ModelsServiceClient,
    model_metadata: ModelMetadata,
    compilation_directory: str,
    local_files: Dict[str, str],
    training_size: Tuple[int, int],
    precision: Literal["fp32", "fp16"],
    workspace_size_gb: int,
    min_batch_size: Optional[int] = None,
    opt_batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    trt_forward_compatible: bool = False,
    same_compute_compatibility: bool = False,
    verify_model: Optional[Callable[[str], None]] = None,
    console: Optional[Console] = None,
) -> None:
    """Build a single TRT engine variant, optionally verify it, and register it.

    If the models service reports the package as already compiled
    (``AlreadyCompiledError``), the variant is skipped silently.

    Args:
        local_files: Mapping from file handle (e.g. ``WEIGHTS_ONNX_FILE``) to
            local path of the downloaded source-package artefact.
        training_size: (height, width) the model was trained with.
        precision: Engine precision to build.
        min_batch_size / opt_batch_size / max_batch_size: When all provided,
            a dynamic-batch engine is built; otherwise a static-batch one.
        verify_model: Optional callback receiving a directory that contains a
            candidate model package; expected to raise on verification failure.
    """
    print_to_console(
        message=f"Building TRT engine - precision={precision}", console=console
    )
    try:
        file_handles_to_register = [
            CLASS_NAMES_FILE,
            INFERENCE_CONFIG_FILE,
            TRT_CONFIG_FILE,
            ENGINE_PLAN_FILE,
        ]
        if KEYPOINTS_METADATA_FILE in local_files:
            file_handles_to_register.append(KEYPOINTS_METADATA_FILE)
        # NOTE(review): bare annotation below is dead code - it binds nothing
        # and the name is not used in this function; candidate for removal.
        model_architecture: str
        engine_path, trt_config, registration_response = execute_compilation(
            models_service_client=models_service_client,
            model_id=model_metadata.model_id,
            model_architecture=model_metadata.model_architecture,
            task_type=model_metadata.task_type,
            model_variant=model_metadata.model_variant,
            file_handles_to_register=file_handles_to_register,
            compilation_directory=compilation_directory,
            onnx_path=local_files[WEIGHTS_ONNX_FILE],
            precision=precision,
            model_input_size=training_size,
            workspace_size_gb=workspace_size_gb,
            min_batch_size=min_batch_size,
            opt_batch_size=opt_batch_size,
            max_batch_size=max_batch_size,
            trt_version_compatible=trt_forward_compatible,
            same_compute_compatibility=same_compute_compatibility,
            console=console,
        )
    except AlreadyCompiledError:
        print_to_console(
            message="Model package already registered - skipping", console=console
        )
        return None
    if verify_model is not None:
        print_to_console(message="Verification of the artefacts...", console=console)
        verify_model_package(
            model_metadata=model_metadata,
            model_package_id=registration_response.model_package_id,
            trt_config=trt_config,
            inference_config_path=local_files[INFERENCE_CONFIG_FILE],
            class_names_path=local_files[CLASS_NAMES_FILE],
            engine_path=engine_path,
            verify_model=verify_model,
            keypoints_metadata_path=local_files.get(KEYPOINTS_METADATA_FILE),
        )
    register_default_model_package_artefacts(
        registration_response=registration_response,
        trt_config=trt_config,
        inference_config_path=local_files[INFERENCE_CONFIG_FILE],
        class_names_path=local_files[CLASS_NAMES_FILE],
        keypoints_metadata_path=local_files.get(KEYPOINTS_METADATA_FILE),
        engine_path=engine_path,
        compilation_directory=compilation_directory,
        models_service_client=models_service_client,
    )
    # NOTE(review): message says "trained" although this flow compiles -
    # possible typo to confirm with the authors.
    print_to_console(
        message="Successfully trained and registered model package", console=console
    )


def verify_model_package(
    model_metadata: ModelMetadata,
    model_package_id: str,
    trt_config: TRTConfig,
    inference_config_path: str,
    class_names_path: str,
    engine_path: str,
    verify_model: Callable[[str], None],
    keypoints_metadata_path: Optional[str],
) -> None:
    """Stage the compiled artefacts into a temp dir and run the verifier on it.

    Builds a directory that looks like a final model package (adjusted
    inference config, TRT config, model config, symlinked class names /
    engine / keypoints metadata) and calls ``verify_model`` on it.

    Raises:
        ModelVerificationError: When verification fails for any reason
            (original ``ModelVerificationError`` instances are re-raised
            unchanged; any other exception is wrapped).
    """
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            logging.info(f"Verifying model package {model_package_id}...")
            adjusted_inference_config_path = os.path.join(
                tmp_dir, INFERENCE_CONFIG_FILE
            )
            prepare_adjusted_inference_config(
                inference_config_path=inference_config_path,
                target_path=adjusted_inference_config_path,
            )
            trt_config_path = os.path.join(tmp_dir, TRT_CONFIG_FILE)
            dump_json(path=trt_config_path, content=trt_config.model_dump())
            # Symlinks avoid copying the (potentially large) artefacts.
            local_class_names_path = os.path.join(tmp_dir, CLASS_NAMES_FILE)
            os.symlink(class_names_path, local_class_names_path)
            local_engine_path = os.path.join(tmp_dir, ENGINE_PLAN_FILE)
            os.symlink(engine_path, local_engine_path)
            model_config_path = os.path.join(tmp_dir, MODEL_CONFIG_FILE)
            model_config = {
                "model_architecture": model_metadata.model_architecture,
                "task_type": model_metadata.task_type,
                "backend_type": "trt",
            }
            dump_json(path=model_config_path, content=model_config)
            if keypoints_metadata_path is not None:
                local_keypoints_metadata_path = os.path.join(
                    tmp_dir, KEYPOINTS_METADATA_FILE
                )
                os.symlink(keypoints_metadata_path, local_keypoints_metadata_path)
            verify_model(tmp_dir)
            logging.info(f"Model package {model_package_id} verified.")
    except ModelVerificationError as error:
        raise error
    except Exception as error:
        raise ModelVerificationError(
            "Could not successfully verify correctness of model compilation"
        ) from error


def register_default_model_package_artefacts(
    registration_response: ModelPackageRegistrationResponse,
    trt_config: TRTConfig,
    inference_config_path: str,
    class_names_path: str,
    keypoints_metadata_path: Optional[str],
    engine_path: str,
    compilation_directory: str,
    models_service_client: ModelsServiceClient,
) -> None:
    """Prepare final artefact files and upload them to the models service.

    Writes the adjusted inference config and the TRT config into
    ``compilation_directory``, computes MD5 checksums for every artefact,
    then delegates the upload + confirmation to
    ``register_model_package_artefacts``.

    Raises:
        CompiledPackageRegistrationError: If preparing the artefact mapping
            fails (and via the delegate, if the upload itself fails).
    """
    try:
        adjusted_inference_config_path = os.path.join(
            compilation_directory, "adjusted_inference_config.json"
        )
        prepare_adjusted_inference_config(
            inference_config_path=inference_config_path,
            target_path=adjusted_inference_config_path,
        )
        trt_config_path = os.path.join(compilation_directory, TRT_CONFIG_FILE)
        dump_json(path=trt_config_path, content=trt_config.model_dump())
        # file handle -> (local path, md5) for every artefact to upload
        local_files_mapping = {
            INFERENCE_CONFIG_FILE: (
                adjusted_inference_config_path,
                calculate_local_file_md5(file_path=adjusted_inference_config_path),
            ),
            CLASS_NAMES_FILE: (
                class_names_path,
                calculate_local_file_md5(file_path=class_names_path),
            ),
            TRT_CONFIG_FILE: (
                trt_config_path,
                calculate_local_file_md5(file_path=trt_config_path),
            ),
            ENGINE_PLAN_FILE: (
                engine_path,
                calculate_local_file_md5(file_path=engine_path),
            ),
        }
        if keypoints_metadata_path is not None:
            local_files_mapping[KEYPOINTS_METADATA_FILE] = (
                keypoints_metadata_path,
                calculate_local_file_md5(file_path=keypoints_metadata_path),
            )
    except Exception as error:
        logging.exception(
            f"Could not register artefacts for package {registration_response.model_package_id}"
        )
        raise CompiledPackageRegistrationError(
            f"Could not register artefacts for package {registration_response.model_package_id}"
        ) from error
    register_model_package_artefacts(
        registration_response=registration_response,
        local_files_mapping=local_files_mapping,
        models_service_client=models_service_client,
    )


def prepare_adjusted_inference_config(
    inference_config_path: str,
    target_path: str,
) -> None:
    """Copy the inference config, disabling dynamic spatial size support.

    TRT engines built here are fixed-resolution, so the flags advertising
    dynamic spatial input are turned off in the registered config.
    """
    inference_config = read_json(inference_config_path)
    inference_config["network_input"]["dynamic_spatial_size_supported"] = False
    inference_config["network_input"]["dynamic_spatial_size_mode"] = None
    dump_json(path=target_path, content=inference_config)
LOGGER.setLevel(logging.INFO)


class EngineBuilder:
    """
    Parses an ONNX graph and builds a TensorRT engine from it.
    """

    def __init__(
        self,
        workspace: int,
    ):
        """Create a TRT builder with a workspace memory limit.

        Args:
            workspace: Workspace memory pool limit in GiB.
        """
        self.trt_logger = InferenceTRTLogger()
        # Register TRT plugins so graphs using plugin ops can be parsed.
        trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
        self.builder = trt.Builder(self.trt_logger)
        self.config = self.builder.create_builder_config()
        self.config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, workspace * (2**30)
        )
        # Populated by create_network(); None until the ONNX graph is parsed.
        self.network: Optional[trt.tensorrt.INetworkDefinition] = None
        self.parser: Optional[trt.OnnxParser] = None
        # Optional shared timing cache (speeds up repeated builds).
        self.cache_manager: Optional[TimingCacheManager] = None

    def set_timing_cache_manager(self, cache_manager: TimingCacheManager) -> None:
        # Attach a manager used to load/persist the TRT timing cache.
        self.cache_manager = cache_manager

    def create_network(self, onnx_path: str) -> None:
        """
        Parse the ONNX graph and create the corresponding TensorRT network definition.
        :param onnx_path: The path to the ONNX graph to load.
        """
        self.network = self.builder.create_network(0)
        self.parser = trt.OnnxParser(self.network, self.trt_logger)

        onnx_path = os.path.realpath(onnx_path)
        with open(onnx_path, "rb") as f:
            if not self.parser.parse(f.read()):
                LOGGER.error("Failed to load ONNX file: {}".format(onnx_path))
                for error in range(self.parser.num_errors):
                    LOGGER.error(self.parser.get_error(error))
                raise NetworkParsingError("Could not parse ONNX file")

        network_inputs = [
            self.network.get_input(i) for i in range(self.network.num_inputs)
        ]
        network_outputs = [
            self.network.get_output(i) for i in range(self.network.num_outputs)
        ]
        # Log the parsed I/O layout for diagnostics.
        LOGGER.info("Network Description")
        for network_input in network_inputs:
            LOGGER.info(
                "Input '{}' with shape {} and dtype {}".format(
                    network_input.name, network_input.shape, network_input.dtype
                )
            )
        for network_output in network_outputs:
            LOGGER.info(
                "Output '{}' with shape {} and dtype {}".format(
                    network_output.name, network_output.shape, network_output.dtype
                )
            )

    def get_static_batch_size_of_input(self) -> int:
        """Return the (static) batch dimension of the network's single input.

        Raises:
            InvalidNetworkInputsError: If the batch dimension cannot be
                interpreted as a static integer.
        """
        network_input = self._get_image_input()
        try:
            return int(network_input.shape[0])
        except ValueError as error:
            raise InvalidNetworkInputsError(
                f"Expected the input to have static batch size, detected shape: {network_input.shape}"
            ) from error

    def create_engine(
        self,
        engine_path: str,
        precision: Literal["fp32", "fp16"],
        input_size: Tuple[int, int],
        dynamic_batch_sizes: Optional[Tuple[int, int, int]] = None,
        trt_version_compatible: bool = False,
        same_compute_compatibility: bool = False,
    ) -> None:
        """Build and serialize the TRT engine for the parsed network.

        Args:
            engine_path: Output path for the serialized engine plan.
            precision: "fp32" or "fp16" (fp16 requires platform support).
            input_size: (height, width) of the image input.
            dynamic_batch_sizes: (min, opt, max) batch sizes; when given, an
                optimization profile with those shapes is added.
            trt_version_compatible: Build a version-compatible engine.
            same_compute_compatibility: Restrict compatibility to GPUs with
                the same compute capability.

        Raises:
            QuantizationNotSupportedError: fp16 requested but unsupported.
            TRTModelCompilationError: Engine serialization failed.
        """
        if self.cache_manager:
            # Seed the builder with a previously persisted timing cache.
            cache_bytes = self.cache_manager.get_cache_for_features()
            cache = self.config.create_timing_cache(cache_bytes)
            self.config.set_timing_cache(cache, ignore_mismatch=False)
        engine_path = os.path.abspath(engine_path)
        engine_dir = os.path.dirname(engine_path)
        os.makedirs(engine_dir, exist_ok=True)
        LOGGER.info("Building {} Engine in {}".format(precision, engine_path))
        network_input = self._get_image_input()
        input_name = network_input.name
        if precision == "fp16":
            if not self.builder.platform_has_fast_fp16:
                raise QuantizationNotSupportedError("FP16 quantization not supported")
            self.config.set_flag(trt.BuilderFlag.FP16)
        if trt_version_compatible:
            self.config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
        if same_compute_compatibility:
            self.config.hardware_compatibility_level = (
                trt.HardwareCompatibilityLevel.SAME_COMPUTE_CAPABILITY
            )
        profile = self.builder.create_optimization_profile()
        if dynamic_batch_sizes:
            bs_min, bs_opt, bs_max = dynamic_batch_sizes
            h, w = input_size
            # Channels-first image input assumed: (batch, 3, height, width).
            profile.set_shape(
                input_name, (bs_min, 3, h, w), (bs_opt, 3, h, w), (bs_max, 3, h, w)
            )
            self.config.add_optimization_profile(profile)
        engine_bytes = self.builder.build_serialized_network(self.network, self.config)
        if engine_bytes is None:
            raise TRTModelCompilationError("Failed to create TRT engine")
        with open(engine_path, "wb") as f:
            LOGGER.info("Serializing engine to file: {:}".format(engine_path))
            f.write(engine_bytes)
        if self.cache_manager:
            # Persist the (possibly updated) timing cache for future builds.
            cache = self.config.get_timing_cache()
            self.cache_manager.save_cache_for_features(cache=cache.serialize())

    def _get_image_input(self) -> trt.ITensor:
        """Return the network's single input tensor.

        Raises:
            TRTModelCompilationError: If called before create_network().
            InvalidNetworkInputsError: If the network has multiple inputs.
        """
        if self.network is None:
            raise TRTModelCompilationError(
                "Attempted to get network input before parsing the model"
            )
        network_inputs = [
            self.network.get_input(i) for i in range(self.network.num_inputs)
        ]
        if len(network_inputs) != 1:
            raise InvalidNetworkInputsError("Detected network with multiple inputs")
        return network_inputs[0]
+ models_service_client=models_service_client, + compilation_features=compilation_features, + ) + + def __init__( + self, + cache_root: str, + models_service_client: ModelsServiceClient, + compilation_features: Dict[str, Any], + ): + self._cache_root = cache_root + self._models_service_client = models_service_client + self._compilation_features = compilation_features + self._should_not_populate_private_cache = False + + def get_cache_for_features(self) -> bytes: + try: + compilation_features_specs = self._attempt_getting_cache_entry() + file_handle = compilation_features_specs.file_handle + download_url = compilation_features_specs.download_url + md5_hash = compilation_features_specs.md5_hash + download_results = download_files_to_directory( + target_dir=self._cache_root, + files_specs=[(file_handle, download_url, md5_hash)], + ) + logging.info( + f"TRT Timing cache hit for compilation features: {self._compilation_features}" + ) + return read_bytes(download_results[file_handle]) + except RequestError as error: + if error.status_code == 404: + logging.info( + f"TRT Timing cache miss for compilation features: {self._compilation_features}" + ) + else: + self._should_not_populate_private_cache = True + logging.warning( + f"Could not retrieve TRT timing cache entry from RF API - status: {error.status_code}, message: {error}" + ) + return b"" + except Exception: + self._should_not_populate_private_cache = True + logging.exception(f"Error in retrieving TRT timing compilation cache") + return b"" + + def save_cache_for_features(self, cache: bytes) -> None: + if self._should_not_populate_private_cache: + return None + try: + registration_response = ( + self._models_service_client.register_private_trt_timing_cache( + compilation_features=self._compilation_features + ) + ) + with tempfile.TemporaryDirectory() as tmp_dir: + cache_entry_path = os.path.join(tmp_dir, "local-copy-of-cache-entry") + write_bytes(path=cache_entry_path, content=cache) + cache_entry_md5 = 
calculate_local_file_md5(file_path=cache_entry_path) + upload_file_to_cloud( + file_path=cache_entry_path, + url=registration_response.upload_specs.signed_url_details.upload_url, + headers=registration_response.upload_specs.signed_url_details.extension_headers, + ) + confirmation = FileConfirmation( + file_handle=registration_response.upload_specs.file_handle, + md5_hash=cache_entry_md5, + ) + self._models_service_client.confirm_private_trt_timing_cache_upload( + cache_key=registration_response.cache_key, + confirmation=confirmation, + ) + logging.info( + f"TRT timing cache saved for compilation features: {self._compilation_features}" + ) + except RequestError as error: + if error.status_code == 409: + return None + except Exception: + logging.exception(f"Error in saving TRT timing compilation cache") + + def _attempt_getting_cache_entry( + self, + ) -> Union[ + ExternalPublicTRTTimingCompilationEntryV1, + ExternalPrivateTRTTimingCompilationEntryV1, + ]: + try: + features_specs = self._models_service_client.get_public_trt_timing_cache( + compilation_features=self._compilation_features + ) + self._should_not_populate_private_cache = True + return features_specs # type: ignore + except RequestError as error: + if error.status_code != 404: + raise error + return self._models_service_client.get_private_trt_timing_cache( # type: ignore + compilation_features=self._compilation_features + ) diff --git a/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/utils.py b/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/utils.py new file mode 100644 index 0000000000..6fb9a3cf36 --- /dev/null +++ b/inference_cli/lib/enterprise/inference_compiler/core/compilation_handlers/utils.py @@ -0,0 +1,380 @@ +import logging +import os +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +from rich.console import Console + +from inference_cli.lib.enterprise.inference_compiler.adapters.models_service import ( + 
from inference_cli.lib.enterprise.inference_compiler.adapters.models_service import (
    FileConfirmation,
    ModelPackageRegistrationResponse,
    ModelsServiceClient,
)
from inference_cli.lib.enterprise.inference_compiler.constants import MODEL_CONFIG_FILE
from inference_cli.lib.enterprise.inference_compiler.core.compilation_handlers.engine_builder import (
    EngineBuilder,
)
from inference_cli.lib.enterprise.inference_compiler.core.compilation_handlers.timing_cache_manager import (
    TimingCacheManager,
)
from inference_cli.lib.enterprise.inference_compiler.core.entities import (
    GPUServerSpecsV1,
    JetsonMachineSpecsV1,
    TRTConfig,
    TRTMachineType,
    TRTModelPackageV1,
)
from inference_cli.lib.enterprise.inference_compiler.errors import (
    AlreadyCompiledError,
    CompiledPackageRegistrationError,
    CorruptedPackageError,
    LackOfSourcePackageError,
    ModelVerificationError,
    PackageDownloadError,
    PackageNegotiationError,
    RequestError,
)
from inference_cli.lib.enterprise.inference_compiler.utils.file_system import (
    dump_json,
    read_json,
)
from inference_cli.lib.enterprise.inference_compiler.utils.http import (
    upload_file_to_cloud,
)
from inference_cli.lib.enterprise.inference_compiler.utils.logging import (
    print_to_console,
)
from inference_models.errors import NoModelPackagesAvailableError
from inference_models.models.auto_loaders.auto_negotiation import (
    negotiate_model_packages,
)
from inference_models.runtime_introspection.core import x_ray_runtime_environment
from inference_models.utils.download import download_files_to_directory
from inference_models.weights_providers.entities import (
    BackendType,
    ModelMetadata,
    ModelPackageMetadata,
    Quantization,
)


def safe_negotiate_model_packages(
    model_metadata: ModelMetadata,
    requested_backends: Union[BackendType, List[BackendType]] = BackendType.ONNX,
    requested_quantization: Union[Quantization, List[Quantization]] = Quantization.FP32,
    allow_untrusted_packages: bool = False,
) -> Tuple[ModelPackageMetadata, Optional[ModelPackageMetadata]]:
    """Negotiate source packages, returning (static-batch, dynamic-batch).

    The static-batch package is required (its absence raises); the
    dynamic-batch package is optional and may be ``None``.

    Raises:
        LackOfSourcePackageError: No suitable source package exists.
        PackageNegotiationError: Any other failure during negotiation.
    """
    try:
        matching_model_packages = negotiate_model_packages(
            model_architecture=model_metadata.model_architecture,
            task_type=model_metadata.task_type,
            model_packages=model_metadata.model_packages,
            requested_backends=requested_backends,
            requested_quantization=requested_quantization,
            allow_untrusted_packages=allow_untrusted_packages,
            verbose=True,
        )
        package_with_static_batch_size = select_package_with_static_batch_size(
            model_packages=matching_model_packages
        )
        package_with_dynamic_batch_size = select_package_with_dynamic_batch_size(
            model_packages=matching_model_packages
        )
        return package_with_static_batch_size, package_with_dynamic_batch_size
    except LackOfSourcePackageError as error:
        # Already a domain-specific error - propagate unchanged.
        raise error
    except NoModelPackagesAvailableError as error:
        raise LackOfSourcePackageError(
            "Could not find model package which could serve as compilation source."
        ) from error
    except Exception as error:
        logging.exception(
            f"Error when selecting model packages for compilation - {error}"
        )
        raise PackageNegotiationError(
            "Error when selecting model packages for compilation."
        ) from error


def select_package_with_static_batch_size(
    model_packages: List[ModelPackageMetadata],
) -> ModelPackageMetadata:
    """Return the first package without dynamic-batch support; raise if none."""
    static_bs_models = [p for p in model_packages if not p.dynamic_batch_size_supported]
    if len(static_bs_models) == 0:
        raise LackOfSourcePackageError(
            "Could not find model package with static batch size"
        )
    return static_bs_models[0]


def select_package_with_dynamic_batch_size(
    model_packages: List[ModelPackageMetadata],
) -> Optional[ModelPackageMetadata]:
    """Return the first package with dynamic-batch support, or None."""
    dynamic_bs_models = [p for p in model_packages if p.dynamic_batch_size_supported]
    if len(dynamic_bs_models) == 0:
        return None
    return dynamic_bs_models[0]


def download_model_package(
    model_architecture: str,
    task_type: str,
    model_package: ModelPackageMetadata,
    target_directory: str,
    expected_files: List[str],
    verify_model: Optional[Callable[[str], None]] = None,
) -> Dict[str, str]:
    """Download a model package's artefacts and optionally verify it.

    Args:
        expected_files: File handles that must be present in the package.
        verify_model: Optional callback invoked with the package directory
            (after a model config is written there); expected to raise on
            verification failure.

    Returns:
        Mapping from file handle to local file path.

    Raises:
        CorruptedPackageError: A required file is missing from the package.
        PackageDownloadError: The download failed.
        ModelVerificationError: The downloaded package failed verification.
    """
    files_specs = [
        (a.file_handle, a.download_url, a.md5_hash)
        for a in model_package.package_artefacts
    ]
    package_dir = os.path.join(target_directory, model_package.package_id)
    file_mapping = {
        a.file_handle: os.path.join(package_dir, a.file_handle)
        for a in model_package.package_artefacts
    }
    # Validate completeness before spending time on the download.
    if any(f not in file_mapping for f in expected_files):
        raise CorruptedPackageError(
            f"At least one of the files {expected_files} missing in model package {model_package.package_id}"
        )
    try:
        os.makedirs(package_dir, exist_ok=True)
        download_files_to_directory(
            target_dir=package_dir,
            files_specs=files_specs,
            verbose=True,
        )
    except Exception as error:
        logging.exception(
            f"Error when downloading model package: {model_package.package_id}"
        )
        raise PackageDownloadError(
            f"Could not download model package - error {error}"
        ) from error
    if verify_model is not None:
        try:
            # The verifier expects a model config file in the package dir.
            model_config_path = os.path.join(package_dir, MODEL_CONFIG_FILE)
            model_config = {
                "model_architecture": model_architecture,
                "task_type": task_type,
                "backend_type": model_package.backend.value,
            }
            dump_json(path=model_config_path, content=model_config)
            verify_model(package_dir)
        except ModelVerificationError as error:
            raise error
        except Exception as error:
            raise ModelVerificationError(
                "Could not successfully verify correctness of model compilation"
            ) from error
    return file_mapping


def get_training_input_size(inference_config_path: str) -> Tuple[int, int]:
    """Read the (height, width) training input size from an inference config.

    Raises:
        CorruptedPackageError: Config missing, unparsable or lacking the
            ``network_input.training_input_size`` entry.
    """
    try:
        inference_config = read_json(path=inference_config_path)
        dimensions = inference_config["network_input"]["training_input_size"]
        return dimensions["height"], dimensions["width"]
    except Exception as error:
        raise CorruptedPackageError(
            f"Could not get training input size from inference config - {error}"
        )


def execute_compilation(
    models_service_client: ModelsServiceClient,
    model_id: str,
    model_architecture: str,
    task_type: Optional[str],
    model_variant: Optional[str],
    file_handles_to_register: List[str],
    compilation_directory: str,
    onnx_path: str,
    precision: Literal["fp32", "fp16"],
    model_input_size: Tuple[int, int],
    workspace_size_gb: int,
    min_batch_size: Optional[int] = None,
    opt_batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    trt_version_compatible: bool = True,
    same_compute_compatibility: bool = False,
    registered_model_features: Optional[dict] = None,
    console: Optional[Console] = None,
) -> Tuple[str, TRTConfig, ModelPackageRegistrationResponse]:
    """Compile an ONNX model into a TRT engine and register the package.

    Flow: parse ONNX -> derive batch-size config -> describe the build
    machine -> "stat-create" the package registration (detects an already
    sealed package early, before the expensive build) -> build the engine
    with a shared timing cache -> register again for fresh upload URLs.

    Returns:
        (engine path, TRT config used, registration response with upload
        specs for the artefacts).

    Raises:
        AlreadyCompiledError: The package is already compiled/sealed (409).
        CompiledPackageRegistrationError: Registration with the service failed.
    """
    runtime_xray = x_ray_runtime_environment()
    os.makedirs(compilation_directory, exist_ok=True)
    engine_builder = EngineBuilder(workspace=workspace_size_gb)
    engine_builder.create_network(onnx_path=onnx_path)
    dynamic_batch_sizes = None
    # Dynamic mode requires all three of min/opt/max to be provided.
    dynamic_dimensions_in_use = all(
        e is not None for e in [min_batch_size, opt_batch_size, max_batch_size]
    )
    static_batch_size = None
    if dynamic_dimensions_in_use:
        dynamic_batch_sizes = min_batch_size, opt_batch_size, max_batch_size
    else:
        static_batch_size = engine_builder.get_static_batch_size_of_input()
    if dynamic_batch_sizes:
        print_to_console(
            message=f"Compiling model with dynamic batch sizes: {dynamic_batch_sizes}, quantization: {precision}",
            console=console,
        )
        trt_config = TRTConfig(
            dynamic_batch_size_min=dynamic_batch_sizes[0],
            dynamic_batch_size_opt=dynamic_batch_sizes[1],
            dynamic_batch_size_max=dynamic_batch_sizes[2],
        )
    else:
        print_to_console(
            message=f"Compiling model with static batch size: {static_batch_size}, quantization: {precision}",
            console=console,
        )
        trt_config = TRTConfig(
            static_batch_size=static_batch_size,
        )
    # Describe the machine performing the build - L4T presence marks Jetson.
    if runtime_xray.l4t_version is not None:
        machine_type = TRTMachineType.JETSON
        machine_specs = JetsonMachineSpecsV1(
            type="jetson-machine-specs-v1",
            l4t_version=runtime_xray.l4t_version,
            device_name=runtime_xray.jetson_type or "unknown",
            driver_version=str(runtime_xray.driver_version),
        )
    else:
        machine_type = TRTMachineType.GPU_SERVER
        machine_specs = GPUServerSpecsV1(
            type="gpu-server-specs-v1",
            driver_version=str(runtime_xray.driver_version),
            os_version=runtime_xray.os_version,
        )
    package_manifest = TRTModelPackageV1(
        type="trt-model-package-v1",
        backend_type="trt",
        dynamic_batch_size=dynamic_batch_sizes is not None,
        static_batch_size=static_batch_size,
        min_batch_size=trt_config.dynamic_batch_size_min,
        opt_batch_size=trt_config.dynamic_batch_size_opt,
        max_batch_size=trt_config.dynamic_batch_size_max,
        quantization=Quantization(precision),
        cuda_device_type=runtime_xray.gpu_devices[0],
        cuda_device_cc=str(runtime_xray.gpu_devices_cc[0]),
        cuda_version=str(runtime_xray.cuda_version),
        trt_version=str(runtime_xray.trt_version),
        same_cc_compatible=same_compute_compatibility,
        trt_forward_compatible=trt_version_compatible,
        trt_lean_runtime_excluded=False,
        machine_type=machine_type,
        machine_specs=machine_specs,
    )
    try:
        # stating the registration, to see if package already sealed
        _ = models_service_client.register_model_package(
            model_id=model_id,
            package_manifest=package_manifest.model_dump(
                by_alias=True, mode="json", exclude_none=True
            ),
            file_handles=file_handles_to_register,
            model_features=registered_model_features,
        )
    except RequestError as error:
        if error.status_code == 409:
            raise AlreadyCompiledError("Model package already compiled.")
        logging.exception("Could not stat-create model package")
        raise CompiledPackageRegistrationError(
            f"Could not register model package - {error}"
        ) from error
    except Exception as error:
        logging.exception("Error while registering model package")
        raise CompiledPackageRegistrationError(
            f"Could not register model package - {error}"
        ) from error
    # Key identifying the timing-cache entry for this build environment.
    compilation_features = {
        "modelArchitecture": model_architecture,
        "taskType": task_type,
        "modelVariant": model_variant,
        "modelInputSize": model_input_size,
        "precision": precision,
        "workspaceSizeGb": workspace_size_gb,
        "trtForwardCompatible": trt_version_compatible,
        "sameCCCompatible": same_compute_compatibility,
        "dynamicBatchSizes": dynamic_batch_sizes,
        "cudaDeviceType": runtime_xray.gpu_devices[0],
        "trtVersion": (
            str(runtime_xray.trt_version) if runtime_xray.trt_version else None
        ),
    }
    cache_manager = TimingCacheManager.init(
        models_service_client=models_service_client,
        compilation_features=compilation_features,
    )
    engine_builder.set_timing_cache_manager(cache_manager=cache_manager)
    engine_path = os.path.join(compilation_directory, "engine.plan")
    engine_builder.create_engine(
        engine_path=engine_path,
        precision=precision,
        input_size=model_input_size,
        dynamic_batch_sizes=dynamic_batch_sizes,  # type: ignore
        trt_version_compatible=trt_version_compatible,
        same_compute_compatibility=same_compute_compatibility,
    )
    try:
        # performing registration again, so that we have fresh upload URL
        registration_result = models_service_client.register_model_package(
            model_id=model_id,
            package_manifest=package_manifest.model_dump(
                by_alias=True, mode="json", exclude_none=True
            ),
            file_handles=file_handles_to_register,
            model_features=registered_model_features,
        )
        return engine_path, trt_config, registration_result
    except RequestError as error:
        if error.status_code == 409:
            raise AlreadyCompiledError("Model package already compiled.")
        logging.exception("Could not stat-create model package")
        raise CompiledPackageRegistrationError(
            f"Could not register model package - {error}"
        ) from error
    except Exception as error:
        logging.exception("Error while registering model package")
        raise CompiledPackageRegistrationError(
            f"Could not register model package - {error}"
        ) from error


def register_model_package_artefacts(
    registration_response: ModelPackageRegistrationResponse,
    local_files_mapping: Dict[str, Tuple[str, str]],
    models_service_client: ModelsServiceClient,
) -> None:
    """Upload package artefacts to their signed URLs, confirm and seal.

    Args:
        local_files_mapping: file handle -> (local path, md5) for every file
            the registration response expects.

    Raises:
        CompiledPackageRegistrationError: Any upload/confirmation failure.
    """
    try:
        confirmations = []
        for file_upload_spec in registration_response.file_upload_specs:
            file_path, file_md5 = local_files_mapping[file_upload_spec.file_handle]
            upload_file_to_cloud(
                file_path=file_path,
                url=file_upload_spec.signed_url_details.upload_url,
                headers=file_upload_spec.signed_url_details.extension_headers,
            )
            confirmations.append(
                FileConfirmation(
                    file_handle=file_upload_spec.file_handle,
                    md5_hash=file_md5,
                )
            )
        models_service_client.confirm_model_package_artefacts(
            model_id=registration_response.model_id,
            model_package_id=registration_response.model_package_id,
            confirmations=confirmations,
            seal_model_package=True,
        )
        logging.info(
            f"Registered package with id: {registration_response.model_package_id} "
            f"for model: {registration_response.model_id}"
        )
    except Exception as error:
        logging.exception(
            f"Could not register artefacts for package {registration_response.model_package_id}"
        )
        raise CompiledPackageRegistrationError(
            f"Could not register artefacts for package {registration_response.model_package_id}"
        ) from error
"yolov11": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_yolo_models( + verify_model=verify_auto_model, + ), + ), + "yolov12": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_yolo_models( + verify_model=verify_auto_model, + ), + ), + "rfdetr": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_rfdetr_models( + verify_model=verify_auto_model, + ), + ), + "yolact": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_yolo_models( + verify_model=verify_auto_model, + ), + ), + "resnet": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_resnet_models( + verify_model=verify_auto_model, + ), + ), + "deep-lab-v3-plus": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_deep_lab_models( + verify_model=verify_auto_model, + ), + ), + "vit": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_vit_models( + verify_model=verify_auto_model, + ), + ), + "yolonas": partial( + compile_and_register_default_model, + compilation_config=CompilationConfig.for_yolo_models( + verify_model=verify_auto_model, + ), + ), +} + + +def compile_model( + model_id: str, + api_key: Optional[str] = None, + trt_forward_compatible: bool = False, + trt_same_cc_compatible: bool = False, + console: Optional[Console] = None, +) -> None: + print_to_console( + message="Inference Compiler", + justify="center", + style="bold green4", + console=console, + ) + models_service_client = ModelsServiceClient.init(api_key=api_key) + print_to_console(message="Retrieving Model metadata...", console=console) + model_metadata = get_model_from_provider( + model_id=model_id, + provider="roboflow", + api_key=api_key, + ) + print_model_info( + model_id=model_id, + model_metadata=model_metadata, + console=console, + ) + compile_and_register( + 
model_metadata=model_metadata, + models_service_client=models_service_client, + trt_forward_compatible=trt_forward_compatible, + trt_same_cc_compatible=trt_same_cc_compatible, + ) + + +def compile_and_register( + model_metadata: ModelMetadata, + models_service_client: ModelsServiceClient, + trt_forward_compatible: bool, + trt_same_cc_compatible: bool, + console: Optional[Console] = None, +) -> None: + print_to_console(message="Model compilation in progress...", console=console) + if model_metadata.model_architecture not in REGISTERED_COMPILATION_HANDLERS: + raise ModelArchitectureNotSupportedError( + f"Model architecture {model_metadata.model_architecture} not supported for compilation." + ) + with tempfile.TemporaryDirectory() as compilation_directory: + REGISTERED_COMPILATION_HANDLERS[model_metadata.model_architecture]( + model_metadata, + models_service_client, + compilation_directory, + trt_forward_compatible, + trt_same_cc_compatible, + console, + ) diff --git a/inference_cli/lib/enterprise/inference_compiler/core/entities.py b/inference_cli/lib/enterprise/inference_compiler/core/entities.py new file mode 100644 index 0000000000..242a13f765 --- /dev/null +++ b/inference_cli/lib/enterprise/inference_compiler/core/entities.py @@ -0,0 +1,164 @@ +from enum import Enum +from typing import Callable, Literal, Optional + +from pydantic import BaseModel, Field + +from inference_cli.lib.enterprise.inference_compiler.constants import ( + DEEP_LAB_MODELS_MAX_DYNAMIC_BATCH_SIZE, + DEEP_LAB_MODELS_MIN_DYNAMIC_BATCH_SIZE, + DEEP_LAB_MODELS_OPT_DYNAMIC_BATCH_SIZE, + DEEP_LAB_MODELS_WORKSPACE_SIZE, + RESNET_MODELS_MAX_DYNAMIC_BATCH_SIZE, + RESNET_MODELS_MIN_DYNAMIC_BATCH_SIZE, + RESNET_MODELS_OPT_DYNAMIC_BATCH_SIZE, + RESNET_MODELS_WORKSPACE_SIZE, + RFDETR_MODELS_MAX_DYNAMIC_BATCH_SIZE, + RFDETR_MODELS_MIN_DYNAMIC_BATCH_SIZE, + RFDETR_MODELS_OPT_DYNAMIC_BATCH_SIZE, + RFDETR_MODELS_WORKSPACE_SIZE, + VIT_MODELS_MAX_DYNAMIC_BATCH_SIZE, + VIT_MODELS_MIN_DYNAMIC_BATCH_SIZE, + 
class CompilationConfig(BaseModel):
    """Architecture-specific TensorRT compilation parameters.

    Bundles the builder workspace size, the dynamic-batch envelope
    (min/opt/max) and an optional post-compilation verification callback
    that receives the compiled package directory.
    """

    workspace_size_gb: int
    min_batch_size: int
    opt_batch_size: int
    max_batch_size: int
    verify_model: Optional[Callable[[str], None]] = Field(default=None)

    @classmethod
    def _with_limits(
        cls,
        workspace_size_gb: int,
        min_batch_size: int,
        opt_batch_size: int,
        max_batch_size: int,
        verify_model: Optional[Callable[[str], None]],
    ) -> "CompilationConfig":
        # Shared builder backing the per-architecture factory methods below.
        return cls(
            workspace_size_gb=workspace_size_gb,
            min_batch_size=min_batch_size,
            opt_batch_size=opt_batch_size,
            max_batch_size=max_batch_size,
            verify_model=verify_model,
        )

    @classmethod
    def for_yolo_models(
        cls, verify_model: Optional[Callable[[str], None]] = None
    ) -> "CompilationConfig":
        """Config preset for the YOLO model family."""
        return cls._with_limits(
            YOLO_MODELS_WORKSPACE_SIZE,
            YOLO_MODELS_MIN_DYNAMIC_BATCH_SIZE,
            YOLO_MODELS_OPT_DYNAMIC_BATCH_SIZE,
            YOLO_MODELS_MAX_DYNAMIC_BATCH_SIZE,
            verify_model,
        )

    @classmethod
    def for_rfdetr_models(
        cls, verify_model: Optional[Callable[[str], None]] = None
    ) -> "CompilationConfig":
        """Config preset for RF-DETR models."""
        return cls._with_limits(
            RFDETR_MODELS_WORKSPACE_SIZE,
            RFDETR_MODELS_MIN_DYNAMIC_BATCH_SIZE,
            RFDETR_MODELS_OPT_DYNAMIC_BATCH_SIZE,
            RFDETR_MODELS_MAX_DYNAMIC_BATCH_SIZE,
            verify_model,
        )

    @classmethod
    def for_resnet_models(
        cls, verify_model: Optional[Callable[[str], None]] = None
    ) -> "CompilationConfig":
        """Config preset for ResNet classifiers."""
        return cls._with_limits(
            RESNET_MODELS_WORKSPACE_SIZE,
            RESNET_MODELS_MIN_DYNAMIC_BATCH_SIZE,
            RESNET_MODELS_OPT_DYNAMIC_BATCH_SIZE,
            RESNET_MODELS_MAX_DYNAMIC_BATCH_SIZE,
            verify_model,
        )

    @classmethod
    def for_vit_models(
        cls, verify_model: Optional[Callable[[str], None]] = None
    ) -> "CompilationConfig":
        """Config preset for ViT classifiers."""
        return cls._with_limits(
            VIT_MODELS_WORKSPACE_SIZE,
            VIT_MODELS_MIN_DYNAMIC_BATCH_SIZE,
            VIT_MODELS_OPT_DYNAMIC_BATCH_SIZE,
            VIT_MODELS_MAX_DYNAMIC_BATCH_SIZE,
            verify_model,
        )

    @classmethod
    def for_deep_lab_models(
        cls, verify_model: Optional[Callable[[str], None]] = None
    ) -> "CompilationConfig":
        """Config preset for DeepLab-v3+ segmentation models."""
        return cls._with_limits(
            DEEP_LAB_MODELS_WORKSPACE_SIZE,
            DEEP_LAB_MODELS_MIN_DYNAMIC_BATCH_SIZE,
            DEEP_LAB_MODELS_OPT_DYNAMIC_BATCH_SIZE,
            DEEP_LAB_MODELS_MAX_DYNAMIC_BATCH_SIZE,
            verify_model,
        )


class GPUServerSpecsV1(BaseModel):
    """Hardware/OS description of a GPU server used for compilation."""

    type: Literal["gpu-server-specs-v1"] = Field(default="gpu-server-specs-v1")
    driver_version: str = Field(alias="driverVersion")
    os_version: str = Field(alias="osVersion")

    class Config:
        # Accept both pythonic field names and camelCase aliases on input.
        populate_by_name = True


class JetsonMachineSpecsV1(BaseModel):
    """Hardware description of a Jetson device used for compilation."""

    type: Literal["jetson-machine-specs-v1"] = Field(default="jetson-machine-specs-v1")
    l4t_version: str = Field(alias="l4tVersion")
    device_name: str = Field(alias="deviceName")
    driver_version: str = Field(alias="driverVersion")

    class Config:
        populate_by_name = True


class Quantization(str, Enum):
    """Numeric precision the engine was built with."""

    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
    INT8 = "int8"
    INT4 = "int4"
    UNKNOWN = "unknown"


class TRTMachineType(str, Enum):
    """Class of machine that produced the engine."""

    GPU_SERVER = "gpu-server"
    JETSON = "jetson"


class TRTModelPackageV1(BaseModel):
    """Manifest describing a compiled TensorRT model package.

    Captures batching configuration, precision, the CUDA/TRT environment the
    engine was built against, and the compatibility flags that govern where
    the engine may be deployed.
    """

    type: Literal["trt-model-package-v1"] = Field(default="trt-model-package-v1")
    backend_type: Literal["trt"] = Field(default="trt", alias="backendType")
    dynamic_batch_size: bool = Field(alias="dynamicBatchSize")
    static_batch_size: Optional[int] = Field(alias="staticBatchSize", default=None)
    min_batch_size: Optional[int] = Field(alias="minBatchSize", default=None)
    opt_batch_size: Optional[int] = Field(alias="optBatchSize", default=None)
    max_batch_size: Optional[int] = Field(alias="maxBatchSize", default=None)
    quantization: Quantization
    cuda_device_type: str = Field(alias="cudaDeviceType")
    cuda_device_cc: str = Field(alias="cudaDeviceCC")
    cuda_version: str = Field(alias="cudaVersion")
    trt_version: str = Field(alias="trtVersion")
    same_cc_compatible: Optional[bool] = Field(alias="sameCCCompatible", default=None)
    trt_forward_compatible: Optional[bool] = Field(
        alias="trtForwardCompatible", default=None
    )
    trt_lean_runtime_excluded: Optional[bool] = Field(
        alias="trtLeanRuntimeExcluded", default=False
    )
    machine_type: TRTMachineType = Field(
        alias="machineType", default=TRTMachineType.GPU_SERVER
    )
    machine_specs: GPUServerSpecsV1 = Field(alias="machineSpecs")

    class Config:
        populate_by_name = True


class TRTConfig(BaseModel):
    """Batch-size configuration requested for a TRT build (all optional)."""

    static_batch_size: Optional[int] = Field(default=None)
    dynamic_batch_size_min: Optional[int] = Field(default=None)
    dynamic_batch_size_opt: Optional[int] = Field(default=None)
    dynamic_batch_size_max: Optional[int] = Field(default=None)


def verify_auto_model(package_dir: str) -> None:
    """Smoke-test a compiled package: load it via AutoModel and run one inference.

    A deliberately non-square zero image is used so shape handling is exercised.

    Raises:
        ModelLoadingError: when the package cannot be loaded.
        ModelInferenceError: when a forward pass on the probe image fails.
    """
    probe_image = np.zeros((512, 513, 3), dtype=np.uint8)
    try:
        loaded_model = AutoModel.from_pretrained(model_id_or_path=package_dir)
    except Exception as error:
        raise ModelLoadingError("Could not load compiled model") from error
    try:
        _ = loaded_model(probe_image)
    except Exception as error:
        raise ModelInferenceError(
            "Could not perform inference from compiled model"
        ) from error
class RemoteAPICallError(CLIError):
    """Base class for failures while talking to remote Roboflow services."""


class RetryError(RemoteAPICallError):
    """Transient remote failure; the operation may be retried."""


class RequestError(RemoteAPICallError):
    """Remote service rejected the request with an HTTP error status."""

    def __init__(self, message: str, status_code: int):
        super().__init__(message)
        # Preserved so callers can branch on specific HTTP statuses (e.g. 409).
        self.status_code = status_code


class InferenceCompilerError(CLIError):
    """Generic inference-compiler failure."""


class RuntimeConfigurationError(CLIError):
    """The local runtime environment is not configured for compilation."""


class LackOfSourcePackageError(CLIError):
    """No source model package is available to compile from."""


class PackageNegotiationError(CLIError):
    """Could not negotiate a suitable model package with the service."""


class PackageDownloadError(CLIError):
    """Downloading a model package failed."""


class CorruptedPackageError(CLIError):
    """A downloaded model package failed integrity checks."""


class TRTCompilerError(CLIError):
    """Base class for TensorRT compilation failures."""


class QuantizationNotSupportedError(TRTCompilerError):
    """Requested quantization is not supported on this device/build."""


class InvalidNetworkInputsError(TRTCompilerError):
    """Network inputs are invalid for TRT engine construction."""


class NetworkParsingError(TRTCompilerError):
    """The source network definition could not be parsed."""


class TRTModelCompilationError(TRTCompilerError):
    """Engine building itself failed."""


class AlreadyCompiledError(CLIError):
    """A compiled package for this configuration is already registered."""


class CompiledPackageRegistrationError(CLIError):
    """Registering the compiled package with the models service failed."""


class ModelArchitectureNotSupportedError(CLIError):
    """No compilation handler is registered for the model architecture."""


class ModelVerificationError(CLIError):
    """Base class for post-compilation verification failures."""


class ModelLoadingError(ModelVerificationError):
    """Compiled model could not be loaded during verification."""


class ModelInferenceError(ModelVerificationError):
    """Compiled model could not run inference during verification."""
def read_json(path: str) -> dict:
    """Load and return the JSON document stored at ``path``.

    Opens the file as UTF-8 explicitly: JSON interchange is UTF-8 by
    specification, and relying on the platform-default encoding breaks
    round-tripping non-ASCII content on e.g. Windows.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)  # type: ignore


def dump_json(path: str, content: dict) -> None:
    """Serialize ``content`` as JSON into ``path``, overwriting any existing file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(content, f)


def calculate_local_file_md5(file_path: str) -> str:
    """Return the hex MD5 digest of the file, streamed in chunks.

    Streaming keeps memory bounded regardless of file size.
    """
    computed_hash = hashlib.md5()
    for file_chunk in stream_file_bytes(path=file_path):
        computed_hash.update(file_chunk)
    return computed_hash.hexdigest()


def stream_file_bytes(
    path: str, chunk_size: int = 16384
) -> Generator[bytes, None, None]:
    """Yield the file content in chunks of at most ``chunk_size`` bytes.

    ``chunk_size`` is clamped to a minimum of 1 so a zero/negative value
    cannot cause an infinite loop of empty reads.
    """
    chunk_size = max(chunk_size, 1)
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            yield chunk


def read_bytes(path: str) -> bytes:
    """Read and return the whole file as raw bytes."""
    with open(path, "rb") as f:
        return f.read()


def write_bytes(path: str, content: bytes) -> None:
    """Write ``content`` to ``path`` as raw bytes, overwriting any existing file."""
    with open(path, "wb") as f:
        f.write(content)
def handle_response_errors(response: Response) -> None:
    """Translate an HTTP error response into the compiler's exception types.

    Retryable status codes (per ``HTTP_CODES_TO_RETRY``) raise ``RetryError``
    so the backoff decorator on callers can re-attempt; any other 4xx/5xx
    raises ``RequestError`` carrying the status code and the server payload.
    """
    status = response.status_code
    if status in HTTP_CODES_TO_RETRY:
        raise RetryError(f"Service returned {response.status_code}")
    if status < 400:
        return
    response_payload = get_error_response_payload(response=response)
    raise RequestError(
        message=f"RF API responded with status: {response.status_code} - error message: {response_payload}",
        status_code=response.status_code,
    )


def get_error_response_payload(response: Response) -> str:
    """Best-effort extraction of an error body: pretty JSON if parseable, raw text otherwise."""
    try:
        parsed_body = response.json()
    except ValueError:
        return response.text  # type: ignore
    return json.dumps(parsed_body, indent=4)


def print_to_console(
    message: str,
    console: Optional[Console] = None,
    style: Optional[Union[str, Style]] = None,
    justify: Optional[JustifyMethod] = None,
) -> None:
    """Print ``message`` via Rich when a console is supplied; no-op otherwise."""
    if console is None:
        return None
    console.print(message, style=style, justify=justify)


def print_model_info(
    model_id: str,
    model_metadata: ModelMetadata,
    console: Optional[Console] = None,
) -> None:
    """Render a model-overview table to the console; no-op when console is ``None``."""
    if console is None:
        return None
    overview = render_table_with_model_overview(
        model_id=model_metadata.model_id,
        requested_model_id=model_id,
        model_architecture=model_metadata.model_architecture,
        model_variant=model_metadata.model_variant,
        task_type=model_metadata.task_type,
        weights_provider="roboflow",
        registered_packages=len(model_metadata.model_packages),
        # Older metadata objects may lack this attribute entirely.
        model_dependencies=getattr(model_metadata, "model_dependencies", None),
    )
    console.print(overview)