diff --git a/CHANGELOG.md b/CHANGELOG.md index 70ba306..997e353 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,19 @@ The rules for this file: * accompany each entry with github issue/PR number (Issue #xyz) --> +## Current development + +### Authors + +- @lilyminium +- @Yoshanuikabundi +- @mattwthompson +- @j-wags +- @jaclark5 (assisted with debugging caching issues) + +### New features +- Added fetching by DOI, hash verification, and caching. (#44, #61, #62) + ## v0.3.0 - 2024-07-29 ### Authors diff --git a/openff/nagl_models/__init__.py b/openff/nagl_models/__init__.py index 93c1baf..6638f5b 100644 --- a/openff/nagl_models/__init__.py +++ b/openff/nagl_models/__init__.py @@ -10,7 +10,7 @@ load_nagl_model_directory_entry_points, validate_nagl_model_path, list_available_nagl_models, - get_models_by_type + get_models_by_type, ) from openff.nagl_models._dynamic_fetch import get_model diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index a8f7616..d30c640 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -6,7 +6,7 @@ import urllib.request import platformdirs from packaging.version import Version - +from openff.nagl_models import validate_nagl_model_path RELEASES_URL = "https://api.github.com/repos/openforcefield/openff-nagl-models/releases" @@ -27,12 +27,16 @@ class HashComparisonFailedException(Exception): class UnableToParseDOIException(Exception): """Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern.""" +class BadFileSuffixError(Exception): + """Exception raised when a model file with an incorrect suffix is requested (this will happen a + lot with the current workings of the ToolkitRegistry.call method, where things like "am1bcc" will + be requested from get_model due to toolkit precedence).""" + def get_release_metadata() -> list[dict]: return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) 
-@functools.lru_cache() def get_model( filename: str, doi: None | str = None, @@ -40,9 +44,9 @@ def get_model( ) -> str: """ Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is: - 1. Try to retrieve the file from the local cache - 2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models - 3. Try to fetch the file from the DOI, if provided + 1. Try to retrieve the file from the installed `openff-nagl-models` python package on disk + 2. Try to retrieve the file from the local cache + 3. Try to fetch the file from the Zenodo DOI, if provided This method will raise an HashComparisonFailedException as soon as a hash mismatch is encountered. So if there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised @@ -68,43 +72,39 @@ def get_model( Returns ------- str - The path to the file if it was found. If the file wasn't found then a FileNotFoundError is rasied. + The path to the file if it was found. If the file wasn't found then a FileNotFoundError is raised. Raises ------ HashComparisonFailedException FileNotFoundError """ - + if not(filename.endswith(".pt")): + raise BadFileSuffixError(f"OpenFF NAGL models are based on PyTorch files and expect a `.pt` suffix. 
Found an unrecognized file path extension " + f"on {filename=}") pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) - cached_path = CACHE_DIR / filename - + # See if the file has a known hash if file_hash is None and filename in KNOWN_HASHES: file_hash = KNOWN_HASHES[filename] + # See if it's available in the openff-nagl-models python package + try: + file_path = validate_nagl_model_path(filename) + assert_hash_equal(file_path, file_hash) + return file_path + except FileNotFoundError: + pass + + # Then check if it's in the cache + cached_path = CACHE_DIR / filename if cached_path.exists(): if file_hash: assert_hash_equal(cached_path, file_hash) return cached_path.as_posix() - release_metadata = get_release_metadata() - - # tags with "v" prefix can't easily be sorted, but the result of passing through Version - # are not necessarily 1:1 with the metadata in the releases, keep both and map between - releases: dict[Version:str] = { - Version(release["tag_name"]): release for release in release_metadata - } - - for version in reversed(sorted(releases)): - release = releases[version] - for file in release["assets"]: - if file["name"] == filename: - return _download_and_verify_file( - file["browser_download_url"], cached_path, file_hash - ) - + # Otherwise try to fetch from DOI if doi: try: match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi) diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 149ff08..84dfa48 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -1,69 +1,18 @@ -import json import os import pathlib import shutil -import urllib.request import platformdirs import pytest -from pytest_socket import SocketBlockedError, disable_socket -import openff.nagl_models._dynamic_fetch from openff.nagl_models import __file__ as root from openff.nagl_models._dynamic_fetch import ( get_model, HashComparisonFailedException, 
UnableToParseDOIException, + BadFileSuffixError, ) - -def mocked_urlretrieve(url, filename): - """Mock downloading files from assets by copying from the models/ directory.""" - old = ( - pathlib.Path(root).parent / "models/am1bcc" / pathlib.Path(filename).name - ).as_posix() - new = (platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" / filename).as_posix() - - shutil.copy(old, new) - - return new, None - - -def mocked_get_release_metadata(): - # can regenerate this file with - # $ wget https://api.github.com/repos/openforcefield/openff-nagl-models/releases - return json.loads( - open(pathlib.Path(root).parent / "tests/data/releases.json").read() - ) - - -@pytest.mark.parametrize( - "known_model", - [ - "openff-gnn-am1bcc-0.0.1-alpha.1.pt", - "openff-gnn-am1bcc-0.1.0-rc.1.pt", - "openff-gnn-am1bcc-0.1.0-rc.2.pt", - "openff-gnn-am1bcc-0.1.0-rc.3.pt", - ], -) -def test_get_known_models(monkeypatch, known_model): - with monkeypatch.context() as m: - m.setattr( - urllib.request, - "urlretrieve", - mocked_urlretrieve, - ) - m.setattr( - openff.nagl_models._dynamic_fetch, - "get_release_metadata", - mocked_get_release_metadata, - ) - - assert get_model(known_model).endswith(known_model) - - assert "OPENFF_NAGL_MODELS" in get_model(known_model) - - @pytest.fixture def hide_cache(): cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" @@ -83,60 +32,100 @@ def hide_cache(): shutil.move(alt_dir, cache_dir) -def test_access_internet_with_empty_cache(hide_cache): - cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" +def test_zenodo_fetching_and_caching(hide_cache): + """ + All of the tests that rely on remote fetching into the cache + and checking whether something is in the cache need to be run in + serial, otherwise they'll interfere with each other, so they're + all consolidated into this one test. 
+ """ + + # This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding + # SHA256 hash of the test file "my_favorite_model.pt" (which is a copy of + # openff-gnn-am1bcc-0.1.0-rc.3.pt) uploaded to that sandbox record + + from pytest_socket import SocketBlockedError, disable_socket, enable_socket + from openff.nagl_models._dynamic_fetch import CACHE_DIR + from openff.nagl_models import get_nagl_model_dirs_paths disable_socket() - # would be nice to test the FileNotFoundError, but much more difficult to get that - # particular network failure vs. checking that the network is accessed at all + # Ensure that the cache is hidden, so a lookup without a DOI fails + with pytest.raises(FileNotFoundError): + + get_model( + "my_favorite_model.pt", + ) + + # Ensure the test file isn't in the cache or the nagl_models package + assert not (os.path.exists(CACHE_DIR / 'my_favorite_model.pt')) + for dir_path in get_nagl_model_dirs_paths(): + assert not (os.path.exists(dir_path / 'my_favorite_model.pt')) + + # Ensure that trying to fetch a + # model fails due to lack of internet access with pytest.raises( SocketBlockedError, ): - get_model.cache_clear() - - get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt") + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + ) + # Ensure that the file can actually be fetched + enable_socket() + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + ) -def test_file_exists_in_cache_without_internet(monkeypatch): - # since tests can run in different orders, make sure the file exists already - with monkeypatch.context() as m: - m.setattr( - urllib.request, - "urlretrieve", - mocked_urlretrieve, - ) - m.setattr( - openff.nagl_models._dynamic_fetch, - "get_release_metadata", - mocked_get_release_metadata, - ) + # Ensure that the file is really in the cache + assert os.path.exists(CACHE_DIR / 'my_favorite_model.pt') + # Ensure that, once fetched, the file can be 
gotten without accessing the internet. + disable_socket() + # Ensure that cached files can be accessed when no optional arguments are provided + get_model( + "my_favorite_model.pt", + ) - assert get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt") + # Ensure that a network call is not made if the requested file is in the cache + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + ) - disable_socket() + # Ensure that cached files can be accessed when all optional arguments are provided + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + ) - get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt") + # Ensure that cached files can be accessed when only hash is provided + get_model( + "my_favorite_model.pt", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + ) + # Ensure that cached files raise hash comparison errors + with pytest.raises(HashComparisonFailedException): + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="wrong_hash", + ) -def test_error_on_missing_file(monkeypatch): - with ( - pytest.raises( +def test_error_on_missing_file(): + with pytest.raises( FileNotFoundError, - match="Could not find asset with name 'FOOBAR", - ), - monkeypatch.context() as m, - ): - m.setattr( - urllib.request, - "urlretrieve", - mocked_urlretrieve, - ) - m.setattr( - openff.nagl_models._dynamic_fetch, - "get_release_metadata", - mocked_get_release_metadata, - ) + match="Could not find asset with name 'FOOBAR"): + get_model("FOOBAR.pt") + +def test_error_on_bad_file_suffix(): + with pytest.raises( + BadFileSuffixError, + match="Found an unrecognized file path extension on filename='FOOBAR.txt'"): get_model("FOOBAR.txt") @@ -150,47 +139,12 @@ def test_error_on_missing_file(monkeypatch): "openff-gnn-am1bcc-0.1.0-rc.3.pt", ], ) -def test_all_models_loadable(model, monkeypatch): +def test_all_models_loadable(model): 
pytest.importorskip("openff.nagl") from openff.nagl.nn._models import GNNModel - with monkeypatch.context() as m: - m.setattr( - urllib.request, - "urlretrieve", - mocked_urlretrieve, - ) - m.setattr( - openff.nagl_models._dynamic_fetch, - "get_release_metadata", - mocked_get_release_metadata, - ) - - GNNModel.load(get_model(model), eval_mode=True) - - -def test_get_model_by_doi_and_hash(hide_cache): - # This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding - # SHA256 hash of the test file uploaded to that sandbox record - get_model( - "my_favorite_model.pt", - doi="10.5072/zenodo.278300", - file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", - ) - - -def test_get_model_by_doi_no_hash(hide_cache): - get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300") - - -def test_get_model_hash_comparison_fails(): - with pytest.raises(HashComparisonFailedException): - get_model( - "my_favorite_model.pt", - doi="10.5072/zenodo.278300", - file_hash="wrong_hash", - ) + GNNModel.load(get_model(model), eval_mode=True) def test_user_provided_hash_conflicts_with_known_hash(): @@ -198,21 +152,9 @@ def test_user_provided_hash_conflicts_with_known_hash(): get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt", file_hash="wrong_hash") -def test_malformed_doi(monkeypatch, hide_cache): - with monkeypatch.context() as m: - m.setattr( - urllib.request, - "urlretrieve", - mocked_urlretrieve, - ) - m.setattr( - openff.nagl_models._dynamic_fetch, - "get_release_metadata", - mocked_get_release_metadata, - ) - - with pytest.raises(UnableToParseDOIException): - get_model("my_favorite_model.pt", doi="zenodo.278300") +def test_malformed_doi(): + with pytest.raises(UnableToParseDOIException): + get_model("nonexistent.pt", doi="zenodo.278300") def test_no_matching_file_at_doi():