Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openff/nagl_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
load_nagl_model_directory_entry_points,
validate_nagl_model_path,
list_available_nagl_models,
get_models_by_type
get_models_by_type,
)
from openff.nagl_models._dynamic_fetch import get_model

Expand Down
48 changes: 24 additions & 24 deletions openff/nagl_models/_dynamic_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.request
import platformdirs
from packaging.version import Version

from openff.nagl_models import validate_nagl_model_path

RELEASES_URL = "https://api.github.com/repos/openforcefield/openff-nagl-models/releases"

Expand All @@ -27,21 +27,25 @@ class HashComparisonFailedException(Exception):
class UnableToParseDOIException(Exception):
"""Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern."""

class BadFileSuffixError(Exception):
"""Exception raised when a model file with an incorrect suffix is requested (this will happen a
lot with the current working of the ToolkitRegistry.call method, where things like "am1bcc" will
be requested from get_model due to toolkit precedence."""


def get_release_metadata() -> list[dict]:
return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8"))


@functools.lru_cache()
def get_model(
filename: str,
doi: None | str = None,
file_hash: None | str = None,
) -> str:
"""
Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is:
1. Try to retrieve the file from the local cache
2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models
1. Try to retrieve the file from the openff-nagl-models python package
Comment thread
j-wags marked this conversation as resolved.
Outdated
2. Try to retrieve the file from the local cache
3. Try to fetch the file from the DOI, if provided
Comment thread
j-wags marked this conversation as resolved.
Outdated

This method will raise an HashComparisonFailedException as soon as a hash mismatch is encountered. So if
Expand All @@ -68,43 +72,39 @@ def get_model(
Returns
-------
str
The path to the file if it was found. If the file wasn't found then a FileNotFoundError is rasied.
The path to the file if it was found. If the file wasn't found then a FileNotFoundError is raised.

Raises
------
HashComparisonFailedException
FileNotFoundError
"""

if not(filename.endswith(".pt")):
raise BadFileSuffixError(f"NAGLToolkitWrapper does not recognize file path extension "
f"on {filename=}, expected '.pt' suffix")
Comment thread
j-wags marked this conversation as resolved.
Outdated
pathlib.Path(CACHE_DIR).mkdir(exist_ok=True)

cached_path = CACHE_DIR / filename

# See if the file has a known hash
if file_hash is None and filename in KNOWN_HASHES:
file_hash = KNOWN_HASHES[filename]

# See if it's available in the openff-nagl-models python package
try:
file_path = validate_nagl_model_path(filename)
assert_hash_equal(file_path, file_hash)
return file_path
except FileNotFoundError:
pass

# Then check if it's in the cache
cached_path = CACHE_DIR / filename
if cached_path.exists():
if file_hash:
assert_hash_equal(cached_path, file_hash)

return cached_path.as_posix()

release_metadata = get_release_metadata()

# tags with "v" prefix can't easily be sorted, but the result of passing through Version
# are not necessarily 1:1 with the metadata in the releases, keep both and map between
releases: dict[Version:str] = {
Version(release["tag_name"]): release for release in release_metadata
}

for version in reversed(sorted(releases)):
release = releases[version]
for file in release["assets"]:
if file["name"] == filename:
return _download_and_verify_file(
file["browser_download_url"], cached_path, file_hash
)

# Otherwise try to fetch from DOI
if doi:
try:
match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi)
Expand Down
209 changes: 79 additions & 130 deletions openff/nagl_models/tests/test_dynamic_fetch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import json
import os
import pathlib
import shutil
import urllib.request

import platformdirs
import pytest
from pytest_socket import SocketBlockedError, disable_socket

import openff.nagl_models._dynamic_fetch
from openff.nagl_models import __file__ as root
from openff.nagl_models._dynamic_fetch import (
get_model,
HashComparisonFailedException,
UnableToParseDOIException,
BadFileSuffixError,
)


Expand All @@ -29,41 +26,6 @@ def mocked_urlretrieve(url, filename):
return new, None


def mocked_get_release_metadata():
# can regenerate this file with
# $ wget https://api.github.com/repos/openforcefield/openff-nagl-models/releases
return json.loads(
open(pathlib.Path(root).parent / "tests/data/releases.json").read()
)


@pytest.mark.parametrize(
"known_model",
[
"openff-gnn-am1bcc-0.0.1-alpha.1.pt",
"openff-gnn-am1bcc-0.1.0-rc.1.pt",
"openff-gnn-am1bcc-0.1.0-rc.2.pt",
"openff-gnn-am1bcc-0.1.0-rc.3.pt",
],
)
def test_get_known_models(monkeypatch, known_model):
with monkeypatch.context() as m:
m.setattr(
urllib.request,
"urlretrieve",
mocked_urlretrieve,
)
m.setattr(
openff.nagl_models._dynamic_fetch,
"get_release_metadata",
mocked_get_release_metadata,
)

assert get_model(known_model).endswith(known_model)

assert "OPENFF_NAGL_MODELS" in get_model(known_model)


@pytest.fixture
def hide_cache():
cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
Expand All @@ -83,60 +45,94 @@ def hide_cache():
shutil.move(alt_dir, cache_dir)


def test_access_internet_with_empty_cache(hide_cache):
cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS"
def test_zenodo_fetching_and_caching(hide_cache):
"""
All of the tests that rely on remote fetching into the cache
and checking whether something is in the cache need to be run in
serial, otherwise they'll interfere with each other, so they're
all consolidated into this one test.
"""

# This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding
# SHA256 hash of the test file "my_favorite_model.pt" (which is a copy of
# openff-gnn-am1bcc-0.1.0-rc.3.pt) uploaded to that sandbox record

from pytest_socket import SocketBlockedError, disable_socket, enable_socket

disable_socket()

# would be nice to test the FileNotFoundError, but much more difficult to get that
# particular network failure vs. checking that the network is accessed at all
# Ensure that the cache is hidden,
with pytest.raises(FileNotFoundError):

get_model(
"my_favorite_model.pt",
)

# Ensure that trying to fetch a
# model fails due to lack of internet access
with pytest.raises(
SocketBlockedError,
):
get_model.cache_clear()
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)

get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt")
# Ensure that the file can actually be fetched
enable_socket()
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
)

# Ensure that, once fetched, the file can be gotten without accessing the internet.
disable_socket()
# Ensure that cached files can be accessed when no optional arguments are provided
get_model(
"my_favorite_model.pt",
)
Comment on lines +88 to +90
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we be extra explicit about where the file is and isn't?

  • Assert that this file is not in the installed Python package (I'd be surprised if it ever was) - feel free to do this anywhere in this test, I want to make sure that the checks into the cache and/or internet actually go there and aren't broken if we accidentally ship our "favorite" little model
  • Assert that this file is present in the cache location

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!


def test_file_exists_in_cache_without_internet(monkeypatch):
# since tests can run in different orders, make sure the file exists already
with monkeypatch.context() as m:
m.setattr(
urllib.request,
"urlretrieve",
mocked_urlretrieve,
)
m.setattr(
openff.nagl_models._dynamic_fetch,
"get_release_metadata",
mocked_get_release_metadata,
)
# Ensure that cached files can be accessed when only doi is provided
Comment thread
j-wags marked this conversation as resolved.
Outdated
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not blocking) I can't figure out a good way (or reason to) test this, but I was a little surprised that an incorrect (but correctly-formatted) DOI could be passed here and it might not necessarily match the contents on Zenodo

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I somewhat addressed this in a spec clarification openforcefield/standards@2b6ef83

)

assert get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt")
# Ensure that cached files can be accessed when all optional arguments are provided
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)

disable_socket()
# Ensure that cached files can be accessed when only hash is provided
get_model(
"my_favorite_model.pt",
file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)

# Ensure that cached files raise hash comparison errors
with pytest.raises(HashComparisonFailedException):
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
file_hash="wrong_hash",
)

get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt")


def test_error_on_missing_file(monkeypatch):
with (
pytest.raises(
#
def test_error_on_missing_file():
with pytest.raises(
FileNotFoundError,
match="Could not find asset with name 'FOOBAR",
),
monkeypatch.context() as m,
):
m.setattr(
urllib.request,
"urlretrieve",
mocked_urlretrieve,
)
m.setattr(
openff.nagl_models._dynamic_fetch,
"get_release_metadata",
mocked_get_release_metadata,
)
match="Could not find asset with name 'FOOBAR"):
get_model("FOOBAR.pt")

def test_error_on_bad_file_suffix():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this behavior should not reach out to the internet, it would be nice to have this test run with and without the socket turned off. Same with test_error_on_bad_file_suffix.

with pytest.raises(
BadFileSuffixError,
match="NAGLToolkitWrapper does not recognize file path extension"):

get_model("FOOBAR.txt")

Expand All @@ -150,69 +146,22 @@ def test_error_on_missing_file(monkeypatch):
"openff-gnn-am1bcc-0.1.0-rc.3.pt",
],
)
def test_all_models_loadable(model, monkeypatch):
def test_all_models_loadable(model):
pytest.importorskip("openff.nagl")

from openff.nagl.nn._models import GNNModel

with monkeypatch.context() as m:
m.setattr(
urllib.request,
"urlretrieve",
mocked_urlretrieve,
)
m.setattr(
openff.nagl_models._dynamic_fetch,
"get_release_metadata",
mocked_get_release_metadata,
)

GNNModel.load(get_model(model), eval_mode=True)


def test_get_model_by_doi_and_hash(hide_cache):
# This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding
# SHA256 hash of the test file uploaded to that sandbox record
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81",
)


def test_get_model_by_doi_no_hash(hide_cache):
get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300")


def test_get_model_hash_comparison_fails():
with pytest.raises(HashComparisonFailedException):
get_model(
"my_favorite_model.pt",
doi="10.5072/zenodo.278300",
file_hash="wrong_hash",
)
GNNModel.load(get_model(model), eval_mode=True)


def test_user_provided_hash_conflicts_with_known_hash():
with pytest.raises(HashComparisonFailedException):
get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt", file_hash="wrong_hash")


def test_malformed_doi(monkeypatch, hide_cache):
with monkeypatch.context() as m:
m.setattr(
urllib.request,
"urlretrieve",
mocked_urlretrieve,
)
m.setattr(
openff.nagl_models._dynamic_fetch,
"get_release_metadata",
mocked_get_release_metadata,
)

with pytest.raises(UnableToParseDOIException):
get_model("my_favorite_model.pt", doi="zenodo.278300")
def test_malformed_doi():
with pytest.raises(UnableToParseDOIException):
get_model("nonexistent.pt", doi="zenodo.278300")


def test_no_matching_file_at_doi():
Expand Down
Loading