-
Notifications
You must be signed in to change notification settings - Fork 252
Add offline/air-gapped model resolution for inference-models cache #2187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
94c9904
835af61
8b77ac4
d00a018
82e5d4a
115b2bb
2a52609
459f2b0
30f77f8
7bbba72
e3e29ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,40 @@ | |
| ) | ||
| from inference_models.models.base.types import PreprocessingMetadata | ||
|
|
||
| def _resolve_cached_model_path(model_id: str) -> str: | ||
| """If the model is already in the inference-models local cache, return the | ||
| package directory path so ``AutoModel.from_pretrained`` can load directly | ||
| from disk without calling the Roboflow API. Returns the original | ||
| *model_id* unchanged when no local cache hit is found. | ||
| """ | ||
| try: | ||
| from inference.core.cache.air_gapped import ( | ||
| _get_inference_models_home, | ||
| _slugify_model_id, | ||
| ) | ||
| from inference.core.env import MODEL_CACHE_DIR | ||
|
|
||
| slug = _slugify_model_id(model_id) | ||
| bases = [MODEL_CACHE_DIR] | ||
| inference_home = _get_inference_models_home() | ||
| if inference_home is not None and inference_home != MODEL_CACHE_DIR: | ||
| bases.append(inference_home) | ||
|
|
||
| for base in bases: | ||
| import os | ||
|
|
||
| slug_dir = os.path.join(base, "models-cache", slug) | ||
| if not os.path.isdir(slug_dir): | ||
| continue | ||
| for package_id in os.listdir(slug_dir): | ||
| package_dir = os.path.join(slug_dir, package_id) | ||
| if os.path.isfile(os.path.join(package_dir, "model_config.json")): | ||
| return package_dir | ||
| except Exception: | ||
| pass | ||
| return model_id | ||
|
|
||
|
|
||
| DEFAULT_COLOR_PALETTE = [ | ||
| "#A351FB", | ||
| "#FF4040", | ||
|
|
@@ -108,7 +142,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): | |
| ) | ||
| ) | ||
| self._model: ObjectDetectionModel = AutoModel.from_pretrained( | ||
| model_id_or_path=model_id, | ||
| model_id_or_path=_resolve_cached_model_path(model_id), | ||
| api_key=self.api_key, | ||
| allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, | ||
| allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, | ||
|
|
@@ -259,7 +293,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): | |
| ) | ||
| ) | ||
| self._model: InstanceSegmentationModel = AutoModel.from_pretrained( | ||
| model_id_or_path=model_id, | ||
| model_id_or_path=_resolve_cached_model_path(model_id), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this needed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than manually loading from cache, we can extend auto loader TTL, and we will need to add a flag to throw an error rather than loading from API when files are missing or integrity is not verified. This will be a trivial change to AutoLoader. |
||
| api_key=self.api_key, | ||
| allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, | ||
| allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, | ||
|
|
@@ -417,7 +451,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): | |
| ) | ||
| ) | ||
| self._model: KeyPointsDetectionModel = AutoModel.from_pretrained( | ||
| model_id_or_path=model_id, | ||
| model_id_or_path=_resolve_cached_model_path(model_id), | ||
| api_key=self.api_key, | ||
| allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, | ||
| allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, | ||
|
|
@@ -627,7 +661,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): | |
| ) | ||
| self._model: Union[ClassificationModel, MultiLabelClassificationModel] = ( | ||
| AutoModel.from_pretrained( | ||
| model_id_or_path=model_id, | ||
| model_id_or_path=_resolve_cached_model_path(model_id), | ||
| api_key=self.api_key, | ||
| allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, | ||
| allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, | ||
|
|
@@ -913,7 +947,7 @@ def __init__(self, model_id: str, api_key: str = None, **kwargs): | |
| ) | ||
| ) | ||
| self._model: SemanticSegmentationModel = AutoModel.from_pretrained( | ||
| model_id_or_path=model_id, | ||
| model_id_or_path=_resolve_cached_model_path(model_id), | ||
| api_key=self.api_key, | ||
| allow_untrusted_packages=ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES, | ||
| allow_direct_local_storage_loading=ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure; probably not.