Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 107 additions & 15 deletions openff/nagl_models/_dynamic_fetch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import functools
import hashlib
import json
import re
import pathlib
import urllib.request

import platformdirs
from packaging.version import Version

Expand All @@ -19,19 +19,73 @@

CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"


class HashComparisonFailedException(Exception):
    """
    Raised when the hash of a NAGL model file does not match the expected hash.

    The expected hash is either a known hash for the file name or a hash
    provided by the user.
    """


class UnableToParseDOIException(Exception):
    """
    Raised when a Zenodo DOI does not follow the expected pattern.

    Expected values look like '10.5281/zenodo.278300' (production) or
    '10.5072/zenodo.278300' (sandbox).
    """


def get_release_metadata() -> list[dict]:
    """
    Fetch and parse the release metadata served at ``RELEASES_URL``.

    Returns
    -------
    list[dict]
        The response body decoded as UTF-8 and parsed with ``json.loads``.
    """
    # Read the body inside a ``with`` block so the HTTP response (and its
    # underlying connection) is closed deterministically rather than being
    # left open until garbage collection.
    with urllib.request.urlopen(RELEASES_URL) as response:
        payload = response.read()

    return json.loads(payload.decode("utf-8"))


@functools.lru_cache()
def get_model(filename: str) -> str:
"""Return the path of a model as cached on disk, downloading if necessary."""
def get_model(
filename: str,
doi: None | str = None,
file_hash: None | str = None,
) -> str:
"""
Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is:
1. Try to retrieve the file from the local cache
2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models
3. Try to fetch the file from the DOI, if provided

This method will raise a HashComparisonFailedException as soon as a hash mismatch is encountered. So if
there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised
immediately, even if a file with a matching name that WOULD satisfy the hash check exists in release
metadata or at a provided Zenodo DOI.

Parameters
----------
filename
The name of the file to search for.
doi
The Zenodo DOI to use as a backup location for fetching the model file if it's not found in the local cache
or in the
[release metadata of an openff-nagl-models release](https://github.com/openforcefield/openff-nagl-models/releases)
on GitHub. For example: "10.5072/zenodo.278300"
file_hash
The sha256 hash of the model file to verify the correct contents. Hash checks are automatically performed
on some OpenFF-released NAGL models. But if the model isn't released by OpenFF and this argument is
not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedException
if unsuccessful. If a user provides a hash value here that disagrees with the known hash for the same file
name, the user-provided hash takes precedence.

Returns
-------
str
The path to the file if it was found. If the file wasn't found then a FileNotFoundError is raised.

Raises
------
HashComparisonFailedException
FileNotFoundError
"""

pathlib.Path(CACHE_DIR).mkdir(exist_ok=True)

cached_path = CACHE_DIR / filename

if file_hash is None and filename in KNOWN_HASHES:
file_hash = KNOWN_HASHES[filename]

if cached_path.exists():
assert _get_sha256(cached_path) == KNOWN_HASHES[filename]
if file_hash:
assert_hash_equal(cached_path, file_hash)

return cached_path.as_posix()

Expand All @@ -47,25 +101,63 @@ def get_model(filename: str) -> str:
release = releases[version]
for file in release["assets"]:
if file["name"] == filename:
path_to_file, _ = urllib.request.urlretrieve(
url=file["browser_download_url"],
filename=cached_path.as_posix(),
)

assert cached_path.exists()
assert path_to_file == cached_path.as_posix()

assert _get_sha256(cached_path) == KNOWN_HASHES[filename], (
f"Hash mismatch for {filename}"
return _download_and_verify_file(
file["browser_download_url"], cached_path, file_hash
)

return cached_path.as_posix()
if doi:
try:
match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi)
if not match:
raise IndexError
prefix, zenodo_id = match.groups()
except (IndexError, AttributeError):
raise UnableToParseDOIException(
f"Unable to parse Zenodo DOI {doi}. DOI values are expected to look "
f"like '10.5281/zenodo.278300' (production) or '10.5072/zenodo.278300' (sandbox)"
)

if prefix == "5072":
file_url = (
f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}"
)
else:
file_url = f"https://zenodo.org/api/records/{zenodo_id}/files/{filename}"

try:
return _download_and_verify_file(file_url, cached_path, file_hash)
except urllib.error.HTTPError:
raise FileNotFoundError(f"No file at {file_url}")

raise FileNotFoundError(
f"Could not find asset with name '{filename}' in any release"
)


def assert_hash_equal(cached_path: pathlib.Path, expected_hash: str) -> None:
    """
    Check that the SHA256 hash of a file equals an expected value.

    Parameters
    ----------
    cached_path
        The path of the file to hash; it is handed to ``_get_sha256``.
    expected_hash
        The hash that the file is expected to have.

    Raises
    ------
    HashComparisonFailedException
        If the computed hash differs from ``expected_hash``.
    """
    actual_hash = _get_sha256(cached_path)

    # Nothing to report when the hashes agree
    if actual_hash == expected_hash:
        return

    raise HashComparisonFailedException(
        "NAGL model file hash check failed. Expected hash is "
        f"{expected_hash} but actual hash is {actual_hash}"
    )


def _download_and_verify_file(
url: str, cached_path: pathlib.Path, file_hash: None | str = None
) -> str:
"""Download a file from URL to cached_path and optionally verify its hash."""
path_to_file, _ = urllib.request.urlretrieve(url, filename=cached_path.as_posix())

assert cached_path.exists()
assert path_to_file == cached_path.as_posix()

if file_hash:
assert_hash_equal(cached_path, file_hash)

return cached_path.as_posix()


def _get_sha256(filename: str) -> str:
"""Get the SHA256 hash of a file from its path, assuming it's a binary file like a PyTorch model."""
hash = hashlib.sha256()
Expand Down
16 changes: 8 additions & 8 deletions openff/nagl_models/openff_nagl_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module only contains the function that will be the entry point that
will be used to find the model files.
"""

import importlib.resources
import os
import pathlib
Expand Down Expand Up @@ -166,7 +167,8 @@ def list_available_nagl_models() -> list[pathlib.Path]:
# look for all .pt files in the cache directory, but only those that are
# expected to also be found in release assets
cached_paths = [
cached_file for cached_file in CACHE_DIR.rglob("*.pt")
cached_file
for cached_file in CACHE_DIR.rglob("*.pt")
if cached_file.name in KNOWN_HASHES
]

Expand Down Expand Up @@ -205,12 +207,12 @@ def get_models_by_type(
--------

Getting the latest pre-release model for am1bcc::

>>> from openff.nagl_models.openff_nagl_models import get_models_by_type
>>> get_models_by_type(model_type="am1bcc")
[PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.0.1-alpha.1.pt'),
PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.1.0-rc.1.pt')]

"""
from packaging.version import Version

Expand All @@ -221,14 +223,12 @@ def get_models_by_type(
"If you are using a custom model, "
"please manually specify the path to the model file."
)

model_files = pathlib.Path(base_dir).glob("*.pt")

# assume everything follows the openff-gnn-<model_type>-<version>.pt format
n_name = len(f"openff-gnn-{model_type}-")
versions_to_paths = {
Version(f.stem[n_name:]): f for f in model_files
}
versions_to_paths = {Version(f.stem[n_name:]): f for f in model_files}
versions = sorted(versions_to_paths.keys())
if production_only:
versions = [v for v in versions if not v.is_prerelease]
Expand Down
81 changes: 76 additions & 5 deletions openff/nagl_models/tests/test_dynamic_fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import pathlib
import shutil
import urllib.request
Expand All @@ -9,7 +10,11 @@

import openff.nagl_models._dynamic_fetch
from openff.nagl_models import __file__ as root
from openff.nagl_models._dynamic_fetch import get_model
from openff.nagl_models._dynamic_fetch import (
get_model,
HashComparisonFailedException,
UnableToParseDOIException,
)


def mocked_urlretrieve(url, filename):
Expand Down Expand Up @@ -59,11 +64,27 @@ def test_get_known_models(monkeypatch, known_model):
assert "OPENFF_NAGL_MODELS" in get_model(known_model)


def test_access_internet_with_empty_cache():
cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
@pytest.fixture
def hide_cache():
cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
alt_dir = str(cache_dir) + "_temp"

if os.path.exists(alt_dir):
raise FileExistsError(f"Temporary directory already exists: {alt_dir}")

if os.path.exists(cache_dir):
shutil.move(cache_dir, alt_dir)

yield

if cache_path.exists():
shutil.rmtree(cache_path)
if os.path.exists(alt_dir):
if os.path.exists(cache_dir):
shutil.rmtree(cache_dir)
shutil.move(alt_dir, cache_dir)


def test_access_internet_with_empty_cache(hide_cache):
cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"

disable_socket()

Expand Down Expand Up @@ -147,3 +168,53 @@ def test_all_models_loadable(model, monkeypatch):
)

GNNModel.load(get_model(model), eval_mode=True)


def test_get_model_by_doi_and_hash(hide_cache):
# This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding
# SHA256 hash of the test file uploaded to that sandbox record
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This record must be sandbox-only? This is my first Google result, which seems unlikely to be what you actually want to point to: https://zenodo.org/records/14335473

A comment or note about where this lives and how the hash was generated would be useful for future developers, I don't think anything else would be necessary here

file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)


def test_get_model_by_doi_no_hash(hide_cache):
    """Fetch a model through a Zenodo sandbox DOI (10.5072 prefix) without passing a hash."""
    sandbox_doi = "10.5072/zenodo.278300"
    get_model("my_favorite_model.pt", doi=sandbox_doi)


def test_get_model_hash_comparison_fails():
    """Passing a hash that does not match the file must raise HashComparisonFailedException."""
    sandbox_doi = "10.5072/zenodo.278300"
    bogus_hash = "wrong_hash"

    with pytest.raises(HashComparisonFailedException):
        get_model("my_favorite_model.pt", doi=sandbox_doi, file_hash=bogus_hash)


def test_user_provided_hash_conflicts_with_known_hash():
    """
    A user-provided hash takes precedence over the known hash for the same file
    name, so a wrong user-provided hash must raise HashComparisonFailedException.
    """
    model_name = "openff-gnn-am1bcc-0.1.0-rc.3.pt"

    with pytest.raises(HashComparisonFailedException):
        get_model(model_name, file_hash="wrong_hash")


def test_malformed_doi(monkeypatch, hide_cache):
    """A DOI lacking the '10.<prefix>/' part must raise UnableToParseDOIException."""
    fetch_module = openff.nagl_models._dynamic_fetch

    with monkeypatch.context() as patcher:
        # Swap the download and release-metadata calls for this module's mocks
        patcher.setattr(urllib.request, "urlretrieve", mocked_urlretrieve)
        patcher.setattr(
            fetch_module, "get_release_metadata", mocked_get_release_metadata
        )

        with pytest.raises(UnableToParseDOIException):
            get_model("my_favorite_model.pt", doi="zenodo.278300")


def test_no_matching_file_at_doi():
    """
    Requesting a file name that is not found must raise a FileNotFoundError
    whose message mentions the Zenodo sandbox URL.
    """
    sandbox_doi = "10.5072/zenodo.278300"

    with pytest.raises(FileNotFoundError, match="sandbox.zenodo"):
        get_model("file_that_doesnt_exist.pt", doi=sandbox_doi)
Loading