diff --git a/inference/core/interfaces/http/handlers/workflows.py b/inference/core/interfaces/http/handlers/workflows.py index d174bbf835..f72c9ebaaa 100644 --- a/inference/core/interfaces/http/handlers/workflows.py +++ b/inference/core/interfaces/http/handlers/workflows.py @@ -45,6 +45,8 @@ ) from inference.core.workflows.prototypes.block import BlockAirGappedInfo +logger = logging.getLogger(__name__) + def handle_describe_workflows_blocks_request( dynamic_blocks_definitions: Optional[List[DynamicBlockDefinition]] = None, diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 2b7855228e..e7f6c58252 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -62,6 +62,41 @@ ) from inference_models.models.base.types import PreprocessingMetadata + +def _resolve_cached_model_path(model_id: str) -> str: + """If the model is already in the inference-models local cache, return the + package directory path so ``AutoModel.from_pretrained`` can load directly + from disk without calling the Roboflow API. Returns the original + *model_id* unchanged when no local cache hit is found. 
+ """ + try: + from inference.core.cache.air_gapped import ( + _get_inference_models_home, + _slugify_model_id, + ) + from inference.core.env import MODEL_CACHE_DIR + + slug = _slugify_model_id(model_id) + bases = [MODEL_CACHE_DIR] + inference_home = _get_inference_models_home() + if inference_home is not None and inference_home != MODEL_CACHE_DIR: + bases.append(inference_home) + + for base in bases: + import os + + slug_dir = os.path.join(base, "models-cache", slug) + if not os.path.isdir(slug_dir): + continue + for package_id in os.listdir(slug_dir): + package_dir = os.path.join(slug_dir, package_id) + if os.path.isfile(os.path.join(package_dir, "model_config.json")): + return package_dir + except Exception: + pass + return model_id + + DEFAULT_COLOR_PALETTE = [ "#A351FB", "#FF4040", @@ -108,7 +143,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): ) ) self._model: ObjectDetectionModel = AutoModel.from_pretrained( - model_id_or_path=model_id, + model_id_or_path=_resolve_cached_model_path(model_id), api_key=self.api_key, allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, @@ -259,7 +294,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): ) ) self._model: InstanceSegmentationModel = AutoModel.from_pretrained( - model_id_or_path=model_id, + model_id_or_path=_resolve_cached_model_path(model_id), api_key=self.api_key, allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, @@ -417,7 +452,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): ) ) self._model: KeyPointsDetectionModel = AutoModel.from_pretrained( - model_id_or_path=model_id, + model_id_or_path=_resolve_cached_model_path(model_id), api_key=self.api_key, allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, 
def _get_model_metadata_from_cache(
    dataset_id: Union[DatasetID, ModelID], version_id: Optional[VersionID]
) -> Optional[Tuple[TaskType, ModelType]]:
    """Resolve ``(task_type, model_type)`` from the local model caches.

    Two on-disk layouts are probed, in order:
      1. the traditional ``model_type.json`` cache entry,
      2. the inference-models ``model_config.json`` package layout.

    Returns:
        ``(task_type, model_type)`` on the first valid hit, ``None`` when
        neither layout yields usable metadata.
    """
    # Layout 1: traditional model_type.json
    model_type_cache_path = construct_model_type_cache_path(
        dataset_id=dataset_id, version_id=version_id
    )
    if os.path.isfile(model_type_cache_path):
        try:
            model_metadata = read_json(path=model_type_cache_path)
            if not model_metadata_content_is_invalid(content=model_metadata):
                return (
                    model_metadata[PROJECT_TASK_TYPE_KEY],
                    model_metadata[MODEL_TYPE_KEY],
                )
        except ValueError as e:
            # Decoding problems are non-fatal: layout 2 may still provide
            # the metadata, so log and fall through.
            logger.warning(
                f"Could not load model description from cache under path: "
                f"{model_type_cache_path} - decoding issue: {e}."
            )

    # Layout 2: inference-models model_config.json
    model_id = f"{dataset_id}/{version_id}" if version_id else dataset_id
    return _get_model_metadata_from_inference_models_cache(model_id)


def _get_model_metadata_from_inference_models_cache(
    model_id: str,
) -> Optional[Tuple[TaskType, ModelType]]:
    """Check the inference-models cache layout for model metadata.

    Looks for ``model_config.json`` under
    ``{base}/models-cache/{slug}/{package_id}/model_config.json`` where *base*
    is ``MODEL_CACHE_DIR`` and, when different, the inference-models home.

    Returns:
        ``(task_type, model_architecture)`` from the first config file that
        carries both fields; ``None`` on any miss or failure. An unavailable
        ``air_gapped`` helper module is treated as a cache miss rather than an
        error, mirroring the best-effort behaviour of the adapter-side
        resolver.
    """
    try:
        from inference.core.cache.air_gapped import (
            _get_inference_models_home,
            _slugify_model_id,
        )
    except ImportError:
        # Defensive: metadata lookup must keep working when the helper
        # module is not shipped with this distribution.
        return None

    slug = _slugify_model_id(model_id)

    bases = [MODEL_CACHE_DIR]
    inference_home = _get_inference_models_home()
    if inference_home is not None and inference_home != MODEL_CACHE_DIR:
        bases.append(inference_home)

    for base in bases:
        slug_dir = os.path.join(base, "models-cache", slug)
        if not os.path.isdir(slug_dir):
            continue
        try:
            # sorted() keeps the scan order deterministic across filesystems.
            package_ids = sorted(os.listdir(slug_dir))
        except OSError:
            # Directory removed between the isdir check and the listing.
            continue
        for package_id in package_ids:
            config_path = os.path.join(slug_dir, package_id, "model_config.json")
            if not os.path.isfile(config_path):
                continue
            try:
                metadata = read_json(path=config_path)
            except ValueError:
                continue
            if not isinstance(metadata, dict):
                continue
            task_type = metadata.get("task_type", "")
            model_arch = metadata.get("model_architecture", "")
            if task_type and model_arch:
                return task_type, model_arch

    return None
def dump_model_config_for_offline_use(
    # NOTE(review): the two leading parameters are reconstructed from body
    # usage (the diff elides them) - confirm against the original signature.
    config_path: str,
    model_architecture: str,
    task_type: TaskType,
    backend_type: Optional[BackendType],
    file_lock_acquire_timeout: int,
    model_id: Optional[str] = None,
    on_file_created: Optional[Callable[[str], None]] = None,
) -> None:
    """Persist model metadata to ``config_path`` for offline/air-gapped use.

    Writes a JSON document with the model architecture, task type, backend
    type and - when provided - the originating ``model_id``, guarded by a
    sibling ``FileLock`` so concurrent processes do not interleave writes.

    Args:
        config_path: Destination path of the ``model_config.json`` file.
        model_architecture: Architecture identifier to record.
        task_type: Task type to record.
        backend_type: Backend type to record (may be ``None``).
        file_lock_acquire_timeout: Seconds to wait for the file lock.
        model_id: Optional originating model id; omitted from the document
            when ``None`` to keep older configs byte-compatible.
        on_file_created: Optional callback fired with ``config_path`` after a
            fresh file is written (not fired when the file already exists).
    """
    if os.path.exists(config_path):
        # Already dumped by a previous run - nothing to do.
        return None
    target_file_dir, target_file_name = os.path.split(config_path)
    lock_path = os.path.join(target_file_dir, f".{target_file_name}.lock")
    content = {
        "model_architecture": model_architecture,
        "task_type": task_type,
        "backend_type": backend_type,
    }
    if model_id is not None:
        content["model_id"] = model_id
    with FileLock(lock_path, timeout=file_lock_acquire_timeout):
        # Re-check under the lock: another process may have created the file
        # between the unlocked existence check above and lock acquisition;
        # without this, both writers would dump and fire on_file_created.
        if os.path.exists(config_path):
            return None
        dump_json(
            path=config_path,
            content=content,
        )
    if on_file_created:
        on_file_created(config_path)