Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ The rules for this file:
* accompany each entry with github issue/PR number (Issue #xyz)
-->

## Current development

### Authors
<!-- GitHub usernames of contributors to this release -->
- @lilyminium
- @Yoshanuikabundi
- @mattwthompson
- @j-wags
- @jaclark5 (assisted with debugging caching issues)

### New features
- Added fetching by DOI, hash verification, and caching. (#44, #61, #62)

## v0.3.0 - 2024-07-29

### Authors
Expand Down
2 changes: 1 addition & 1 deletion openff/nagl_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
load_nagl_model_directory_entry_points,
validate_nagl_model_path,
list_available_nagl_models,
get_models_by_type
get_models_by_type,
)
from openff.nagl_models._dynamic_fetch import get_model

Expand Down
50 changes: 25 additions & 25 deletions openff/nagl_models/_dynamic_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.request
import platformdirs
from packaging.version import Version

from openff.nagl_models import validate_nagl_model_path

RELEASES_URL = "https://api.github.com/repos/openforcefield/openff-nagl-models/releases"

Expand All @@ -27,22 +27,26 @@ class HashComparisonFailedException(Exception):
class UnableToParseDOIException(Exception):
    """Exception raised when a Zenodo DOI cannot be parsed according to the expected pattern."""

class BadFileSuffixError(Exception):
    """Exception raised when a model file with an incorrect suffix is requested.

    This will happen a lot with the current working of the ToolkitRegistry.call method, where
    things like "am1bcc" will be requested from get_model due to toolkit precedence.
    """


def get_release_metadata() -> list[dict]:
    """
    Fetch and decode the JSON document served at ``RELEASES_URL``.

    Returns
    -------
    list[dict]
        The decoded JSON payload returned by the releases endpoint.

    Raises
    ------
    urllib.error.URLError
        If the request to ``RELEASES_URL`` cannot be completed.
    """
    # Use the response as a context manager so the underlying connection is closed
    # promptly instead of being left for the garbage collector.
    with urllib.request.urlopen(RELEASES_URL) as response:
        return json.loads(response.read().decode("utf-8"))


@functools.lru_cache()
def get_model(
filename: str,
doi: None | str = None,
file_hash: None | str = None,
) -> str:
"""
Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is:
1. Try to retrieve the file from the local cache
2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models
3. Try to fetch the file from the DOI, if provided
1. Try to retrieve the file from the installed `openff-nagl-models` python package on disk
2. Try to retrieve the file from the local cache
3. Try to fetch the file from the Zenodo DOI, if provided

This method will raise a HashComparisonFailedException as soon as a hash mismatch is encountered. So if
there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised
Expand All @@ -68,43 +72,39 @@ def get_model(
Returns
-------
str
The path to the file if it was found. If the file wasn't found then a FileNotFoundError is rasied.
The path to the file if it was found. If the file wasn't found then a FileNotFoundError is raised.

Raises
------
HashComparisonFailedException
FileNotFoundError
"""

if not(filename.endswith(".pt")):
raise BadFileSuffixError(f"OpenFF NAGL models are based on PyTorch files and expect a `.pt` suffix. Found an unrecognized file path extension "
f"on {filename=}")
pathlib.Path(CACHE_DIR).mkdir(exist_ok=True)

cached_path = CACHE_DIR / filename

# See if the file has a known hash
if file_hash is None and filename in KNOWN_HASHES:
file_hash = KNOWN_HASHES[filename]

# See if it's available in the openff-nagl-models python package
try:
file_path = validate_nagl_model_path(filename)
assert_hash_equal(file_path, file_hash)
return file_path
except FileNotFoundError:
pass

# Then check if it's in the cache
cached_path = CACHE_DIR / filename
if cached_path.exists():
if file_hash:
assert_hash_equal(cached_path, file_hash)

return cached_path.as_posix()

release_metadata = get_release_metadata()

# tags with "v" prefix can't easily be sorted, but the result of passing through Version
# are not necessarily 1:1 with the metadata in the releases, keep both and map between
releases: dict[Version:str] = {
Version(release["tag_name"]): release for release in release_metadata
}

for version in reversed(sorted(releases)):
release = releases[version]
for file in release["assets"]:
if file["name"] == filename:
return _download_and_verify_file(
file["browser_download_url"], cached_path, file_hash
)

# Otherwise try to fetch from DOI
if doi:
try:
match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi)
Expand Down
Loading
Loading