Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openff/nagl_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
load_nagl_model_directory_entry_points,
validate_nagl_model_path,
list_available_nagl_models,
get_models_by_type
get_models_by_type,
)
from openff.nagl_models._dynamic_fetch import get_model

Expand Down
64 changes: 41 additions & 23 deletions openff/nagl_models/_dynamic_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.request
import platformdirs
from packaging.version import Version

from openff.nagl_models import validate_nagl_model_path

RELEASES_URL = "https://api.github.com/repos/openforcefield/openff-nagl-models/releases"

Expand All @@ -27,6 +27,11 @@ class HashComparisonFailedException(Exception):
class UnableToParseDOIException(Exception):
"""Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern."""

class BadFileSuffixError(Exception):
"""Exception raised when a model file with an incorrect suffix is requested (this will happen a
lot with the current working of the ToolkitRegistry.call method, where things like "am1bcc" will
be requested from get_model due to toolkit precedence."""


def get_release_metadata() -> list[dict]:
return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8"))
Expand All @@ -40,9 +45,10 @@ def get_model(
) -> str:
"""
Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is:
1. Try to retrieve the file from the local cache
2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models
3. Try to fetch the file from the DOI, if provided
1. Try to retrieve the file from the openff-nagl-models python package
Comment thread
j-wags marked this conversation as resolved.
Outdated
2. Try to retrieve the file from the local cache
3. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models
4. Try to fetch the file from the DOI, if provided

This method will raise an HashComparisonFailedException as soon as a hash mismatch is encountered. So if
there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised
Expand Down Expand Up @@ -75,36 +81,48 @@ def get_model(
HashComparisonFailedException
FileNotFoundError
"""

if not(filename.endswith(".pt")):
raise BadFileSuffixError(f"NAGLToolkitWrapper does not recognize file path extension "
f"on {filename=}, expected '.pt' suffix")
Comment thread
j-wags marked this conversation as resolved.
Outdated
pathlib.Path(CACHE_DIR).mkdir(exist_ok=True)

cached_path = CACHE_DIR / filename

# See if the file has a known hash
if file_hash is None and filename in KNOWN_HASHES:
file_hash = KNOWN_HASHES[filename]

# See if it's available in the openff-nagl-models python package
try:
file_path = validate_nagl_model_path(filename)
assert_hash_equal(file_path, file_hash)
return file_path
except FileNotFoundError:
pass

# Then check if it's in the cache
cached_path = CACHE_DIR / filename
if cached_path.exists():
if file_hash:
assert_hash_equal(cached_path, file_hash)

return cached_path.as_posix()

release_metadata = get_release_metadata()

# tags with "v" prefix can't easily be sorted, but the result of passing through Version
# are not necessarily 1:1 with the metadata in the releases, keep both and map between
releases: dict[Version:str] = {
Version(release["tag_name"]): release for release in release_metadata
}

for version in reversed(sorted(releases)):
release = releases[version]
for file in release["assets"]:
if file["name"] == filename:
return _download_and_verify_file(
file["browser_download_url"], cached_path, file_hash
)

# release_metadata = get_release_metadata()
#
# # tags with "v" prefix can't easily be sorted, but the result of passing through Version
# # are not necessarily 1:1 with the metadata in the releases, keep both and map between
# releases: dict[Version:str] = {
# Version(release["tag_name"]): release for release in release_metadata
# }
#
# for version in reversed(sorted(releases)):
# release = releases[version]
# for file in release["assets"]:
# if file["name"] == filename:
# return _download_and_verify_file(
# file["browser_download_url"], cached_path, file_hash
# )

# Otherwise try to fetch from DOI
if doi:
try:
match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi)
Expand Down
Loading
Loading