From 710500e43f24673cb4fd524941c47fb803a203d5 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 27 Jun 2025 12:57:40 -0700 Subject: [PATCH 01/14] initial implementation of zenodo fetching and custom hash checking --- openff/nagl_models/_dynamic_fetch.py | 36 +++++++++++++++--- .../nagl_models/tests/test_dynamic_fetch.py | 37 +++++++++++++++++-- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index a3008d9..701400b 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -1,6 +1,7 @@ import functools import hashlib import json +import re import pathlib import urllib.request @@ -24,14 +25,21 @@ def get_release_metadata() -> list[dict]: @functools.lru_cache() -def get_model(filename: str) -> str: +def get_model(filename: str, + doi: None | str = None, + file_hash: None | str = None) -> str: """Return the path of a model as cached on disk, downloading if necessary.""" pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) cached_path = CACHE_DIR / filename + check_hash = file_hash + if check_hash is None and filename in KNOWN_HASHES: + check_hash = KNOWN_HASHES[filename] + if cached_path.exists(): - assert _get_sha256(cached_path) == KNOWN_HASHES[filename] + if check_hash: + assert _get_sha256(cached_path) == check_hash return cached_path.as_posix() @@ -55,12 +63,30 @@ def get_model(filename: str) -> str: assert cached_path.exists() assert path_to_file == cached_path.as_posix() - assert _get_sha256(cached_path) == KNOWN_HASHES[filename], ( - f"Hash mismatch for {filename}" - ) + if check_hash: + assert _get_sha256(cached_path) == check_hash, ( + f"Hash mismatch for {filename}" + ) return cached_path.as_posix() + if doi: + zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0] + + # Remove "sandbox." to convert this to "real" zenodo before merge + # Or keep in with a testing flag? + file_url = f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" + path_to_file, _ = urllib.request.urlretrieve(file_url, + filename=cached_path.as_posix()) + assert cached_path.exists() + assert path_to_file == cached_path.as_posix() + + if check_hash: + assert _get_sha256(cached_path) ==file_hash, ( + f"Hash mismatch for {filename}" + ) + return cached_path.as_posix() + raise FileNotFoundError( f"Could not find asset with name '{filename}' in any release" ) diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 42ea22e..5beb733 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -1,4 +1,5 @@ import json +import os import pathlib import shutil import urllib.request @@ -58,12 +59,29 @@ def test_get_known_models(monkeypatch, known_model): assert "OPENFF_NAGL_MODELS" in get_model(known_model) +@pytest.fixture +def hide_cache(): + cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" + alt_dir = str(cache_dir) + "_temp" -def test_access_internet_with_empty_cache(): + if os.path.exists(alt_dir): + raise FileExistsError(f"Temporary directory already exists: {alt_dir}") + + if os.path.exists(cache_dir): + shutil.move(cache_dir, alt_dir) + + yield + + if os.path.exists(alt_dir): + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + shutil.move(alt_dir, cache_dir) + +def test_access_internet_with_empty_cache(hide_cache): cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" - if cache_path.exists(): - shutil.rmtree(cache_path) + #if cache_path.exists(): + # shutil.rmtree(cache_path) disable_socket() @@ -147,3 +165,16 @@ def test_all_models_loadable(model, monkeypatch): ) GNNModel.load(get_model(model), eval_mode=True) + +def test_get_model_by_doi(hide_cache): + get_model("my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81") + +def test_get_model_hash_comparison_fails(): + with pytest.raises(AssertionError): + get_model("my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="wrong_hash") + + From a1d0800b09bda3446db758420b166b721e071513 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 27 Jun 2025 12:59:22 -0700 Subject: [PATCH 02/14] format and clean up --- openff/nagl_models/_dynamic_fetch.py | 28 +++++++++++-------- .../nagl_models/tests/test_dynamic_fetch.py | 23 +++++++-------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 701400b..6d5c845 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -20,14 +20,15 @@ CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" + def get_release_metadata() -> list[dict]: return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) @functools.lru_cache() -def get_model(filename: str, - doi: None | str = None, - file_hash: None | str = None) -> str: +def get_model( + filename: str, doi: None | str = None, file_hash: None | str = None +) -> str: """Return the path of a model as cached on disk, downloading if necessary.""" pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) @@ -64,9 +65,9 @@ def get_model(filename: str, assert path_to_file == cached_path.as_posix() if check_hash: - assert _get_sha256(cached_path) == check_hash, ( - f"Hash mismatch for {filename}" - ) + assert ( + _get_sha256(cached_path) == check_hash + ), f"Hash mismatch for {filename}" return cached_path.as_posix() @@ -75,16 +76,19 @@ def get_model(filename: str, # Remove "sandbox." to convert this to "real" zenodo before merge # Or keep in with a testing flag? - file_url = f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" - path_to_file, _ = urllib.request.urlretrieve(file_url, - filename=cached_path.as_posix()) + file_url = ( + f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" + ) + path_to_file, _ = urllib.request.urlretrieve( + file_url, filename=cached_path.as_posix() + ) assert cached_path.exists() assert path_to_file == cached_path.as_posix() if check_hash: - assert _get_sha256(cached_path) ==file_hash, ( - f"Hash mismatch for {filename}" - ) + assert ( + _get_sha256(cached_path) == file_hash + ), f"Hash mismatch for {filename}" return cached_path.as_posix() raise FileNotFoundError( diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 5beb733..ad90f15 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -59,6 +59,7 @@ def test_get_known_models(monkeypatch, known_model): assert "OPENFF_NAGL_MODELS" in get_model(known_model) + @pytest.fixture def hide_cache(): cache_dir = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" @@ -77,12 +78,10 @@ def hide_cache(): shutil.rmtree(cache_dir) shutil.move(alt_dir, cache_dir) + def test_access_internet_with_empty_cache(hide_cache): cache_path = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" - #if cache_path.exists(): - # shutil.rmtree(cache_path) - disable_socket() # would be nice to test the FileNotFoundError, but much more difficult to get that @@ -166,15 +165,17 @@ def test_all_models_loadable(model, monkeypatch): GNNModel.load(get_model(model), eval_mode=True) + def test_get_model_by_doi(hide_cache): - get_model("my_favorite_model.pt", - doi="10.5072/zenodo.278300", - file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81") + get_model( + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + ) + def test_get_model_hash_comparison_fails(): with pytest.raises(AssertionError): - get_model("my_favorite_model.pt", - doi="10.5072/zenodo.278300", - file_hash="wrong_hash") - - + get_model( + "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="wrong_hash" + ) From 68524b638b3fc53bc882d15865a56c900eeff873 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 4 Jul 2025 11:22:23 -0700 Subject: [PATCH 03/14] fill out docstring --- openff/nagl_models/_dynamic_fetch.py | 37 ++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 6d5c845..06f57e9 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -27,9 +27,35 @@ def get_release_metadata() -> list[dict]: @functools.lru_cache() def get_model( - filename: str, doi: None | str = None, file_hash: None | str = None + filename: str, doi: None | str = None, file_hash: None | str = None, _sandbox: bool = False ) -> str: - """Return the path of a model as cached on disk, downloading if necessary.""" + """ + Return the path of a model as cached on disk, downloading if necessary. + + Parameters + ---------- + filename : str + The name of the file to search for. + doi : typing.Optional[str], default=None + The Zenodo DOI to use as a backup location for fetching the model file if it's not found in the local cache + or in the + [release metadata of an openff-nagl-models release](https://github.com/openforcefield/openff-nagl-models/releases) + on GitHub. For example: "10.5072/zenodo.278300" + file_hash : typing.Optional[str], default=None + The sha256 hash of the model file to verify the correct contents. Hash checks are automatically performed + on some OpenFF-released NAGL models. But if the model isn't released by OpenFF and this argument is + not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedError if + unsuccessful. + _sandbox : bool, default=False + Whether to connect to sandbox.zenodo.com instead of zenodo.com. Used for testing. + + Returns + ------- + typing.Optional[pathlib.Path] + The path to the file if it was found, otherwise None. + + + """ pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) cached_path = CACHE_DIR / filename @@ -86,9 +112,10 @@ def get_model( assert path_to_file == cached_path.as_posix() if check_hash: - assert ( - _get_sha256(cached_path) == file_hash - ), f"Hash mismatch for {filename}" + actual_hash = _get_sha256(cached_path) + if actual_hash != file_hash: + raise HashComparisonFailedError(f"NAGL model file hash check failed. Expected hash is {file_hash} but computed hash is {actual_hash}") + return cached_path.as_posix() raise FileNotFoundError( From ff6fd70be24a0f3964a9ce89b37b114708eca599 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Mon, 7 Jul 2025 11:43:43 -0700 Subject: [PATCH 04/14] Create custom exception for file hash comparison failures and test --- openff/nagl_models/_dynamic_fetch.py | 21 ++++++++++++------- .../nagl_models/tests/test_dynamic_fetch.py | 4 ++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 06f57e9..4244ec4 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -4,7 +4,7 @@ import re import pathlib import urllib.request - +from openff.utilities.exceptions import OpenFFError import platformdirs from packaging.version import Version @@ -20,6 +20,8 @@ CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" +class HashComparisonFailedException(OpenFFError): + """Exception raised when a NAGL file being loaded fails a comparison to a known or user-provided hash.""" def get_release_metadata() -> list[dict]: return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) @@ -56,6 +58,13 @@ def get_model( """ + + def assert_hash_equal(cached_path, expected_hash): + actual_hash = _get_sha256(cached_path) + if actual_hash != expected_hash: + raise HashComparisonFailedException(f"NAGL model file hash check failed. Expected hash is " + f"{expected_hash} but actual hash is {actual_hash}") + pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) cached_path = CACHE_DIR / filename @@ -66,7 +75,7 @@ def get_model( if cached_path.exists(): if check_hash: - assert _get_sha256(cached_path) == check_hash + assert_hash_equal(cached_path, check_hash) return cached_path.as_posix() @@ -91,9 +100,7 @@ def get_model( assert path_to_file == cached_path.as_posix() if check_hash: - assert ( - _get_sha256(cached_path) == check_hash - ), f"Hash mismatch for {filename}" + assert_hash_equal(cached_path, check_hash) return cached_path.as_posix() @@ -112,9 +119,7 @@ def get_model( assert path_to_file == cached_path.as_posix() if check_hash: - actual_hash = _get_sha256(cached_path) - if actual_hash != file_hash: - raise HashComparisonFailedError(f"NAGL model file hash check failed. Expected hash is {file_hash} but computed hash is {actual_hash}") + assert_hash_equal(cached_path, check_hash) return cached_path.as_posix() diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index ad90f15..dac7f4f 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -10,7 +10,7 @@ import openff.nagl_models._dynamic_fetch from openff.nagl_models import __file__ as root -from openff.nagl_models._dynamic_fetch import get_model +from openff.nagl_models._dynamic_fetch import get_model, HashComparisonFailedException def mocked_urlretrieve(url, filename): @@ -175,7 +175,7 @@ def test_get_model_by_doi(hide_cache): def test_get_model_hash_comparison_fails(): - with pytest.raises(AssertionError): + with pytest.raises(HashComparisonFailedException): get_model( "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="wrong_hash" ) From e064711b78a584085b8387e2981f28aed6eae543 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Mon, 7 Jul 2025 13:28:52 -0700 Subject: [PATCH 05/14] beef up docstrings and add tests --- openff/nagl_models/_dynamic_fetch.py | 74 ++++++++++++++----- .../nagl_models/tests/test_dynamic_fetch.py | 47 +++++++++++- 2 files changed, 98 insertions(+), 23 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 4244ec4..315df66 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -20,19 +20,36 @@ CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" + class HashComparisonFailedException(OpenFFError): """Exception raised when a NAGL file being loaded fails a comparison to a known or user-provided hash.""" + +class UnableToParseDOIException(OpenFFError): + """Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern.""" + + def get_release_metadata() -> list[dict]: return json.loads(urllib.request.urlopen(RELEASES_URL).read().decode("utf-8")) @functools.lru_cache() def get_model( - filename: str, doi: None | str = None, file_hash: None | str = None, _sandbox: bool = False -) -> str: + filename: str, + doi: None | str = None, + file_hash: None | str = None, + _sandbox: bool = False, +) -> pathlib.Path: """ - Return the path of a model as cached on disk, downloading if necessary. + Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is: + 1. Try to retrieve the file from the local cache + 2. Try to fetch the file from a release of https://github.com/openforcefield/openff-nagl-models + 3. Try to fetch the file from the DOI, if provided + + This method will raise an HashComparisonFailedException as soon as a hash mismatch is encountered. So if + there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised + immediately, even if a file with a matching name that WOULD satisfy the hash check exists in release + metadata or at a provided Zenodo DOI, Parameters ---------- @@ -46,24 +63,30 @@ def get_model( file_hash : typing.Optional[str], default=None The sha256 hash of the model file to verify the correct contents. Hash checks are automatically performed on some OpenFF-released NAGL models. But if the model isn't released by OpenFF and this argument is - not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedError if - unsuccessful. + not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedException + if unsuccessful. If a user provides a hash value here that disagrees with the known hash for the same file + name, the user-provided hash takes precedence. _sandbox : bool, default=False Whether to connect to sandbox.zenodo.com instead of zenodo.com. Used for testing. Returns ------- - typing.Optional[pathlib.Path] - The path to the file if it was found, otherwise None. - + pathlib.Path + The path to the file if it was found. If the file wasn't found then a FileNotFoundError is rasied. + Raises + ------ + HashComparisonFailedException + FileNotFoundError """ def assert_hash_equal(cached_path, expected_hash): actual_hash = _get_sha256(cached_path) if actual_hash != expected_hash: - raise HashComparisonFailedException(f"NAGL model file hash check failed. Expected hash is " - f"{expected_hash} but actual hash is {actual_hash}") + raise HashComparisonFailedException( + f"NAGL model file hash check failed. Expected hash is " + f"{expected_hash} but actual hash is {actual_hash}" + ) pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) @@ -105,16 +128,27 @@ def assert_hash_equal(cached_path, expected_hash): return cached_path.as_posix() if doi: - zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0] - - # Remove "sandbox." to convert this to "real" zenodo before merge - # Or keep in with a testing flag? - file_url = ( - f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" - ) - path_to_file, _ = urllib.request.urlretrieve( - file_url, filename=cached_path.as_posix() - ) + try: + zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0] + except IndexError: + raise UnableToParseDOIException( + f"Unable to parse Zenodo DOI {doi}. DOI values are expected to look " + f"like '10.5072/zenodo.278300'" + ) + + if _sandbox: + file_url = ( + f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" + ) + else: + file_url = f"https://zenodo.org/api/records/{zenodo_id}/files/{filename}" + + try: + path_to_file, _ = urllib.request.urlretrieve( + file_url, filename=cached_path.as_posix() + ) + except urllib.error.HTTPError: + raise FileNotFoundError(f"No file at {file_url}") assert cached_path.exists() assert path_to_file == cached_path.as_posix() diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index dac7f4f..cacb5a5 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -10,7 +10,11 @@ import openff.nagl_models._dynamic_fetch from openff.nagl_models import __file__ as root -from openff.nagl_models._dynamic_fetch import get_model, HashComparisonFailedException +from openff.nagl_models._dynamic_fetch import ( + get_model, + HashComparisonFailedException, + UnableToParseDOIException, +) def mocked_urlretrieve(url, filename): @@ -166,16 +170,53 @@ def test_all_models_loadable(model, monkeypatch): GNNModel.load(get_model(model), eval_mode=True) -def test_get_model_by_doi(hide_cache): +def test_get_model_by_doi_and_hash(hide_cache): get_model( "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", + _sandbox=True, ) +def test_get_model_by_doi_no_hash(hide_cache): + get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300", _sandbox=True) + + def test_get_model_hash_comparison_fails(): with pytest.raises(HashComparisonFailedException): get_model( - "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="wrong_hash" + "my_favorite_model.pt", + doi="10.5072/zenodo.278300", + file_hash="wrong_hash", + _sandbox=True, + ) + + +def test_user_provided_hash_conflicts_with_known_hash(): + with pytest.raises(HashComparisonFailedException): + get_model("openff-gnn-am1bcc-0.1.0-rc.3.pt", file_hash="wrong_hash") + + +def test_malformed_doi(monkeypatch, hide_cache): + with monkeypatch.context() as m: + m.setattr( + urllib.request, + "urlretrieve", + mocked_urlretrieve, + ) + m.setattr( + openff.nagl_models._dynamic_fetch, + "get_release_metadata", + mocked_get_release_metadata, + ) + + with pytest.raises(UnableToParseDOIException): + get_model("my_favorite_model.pt", doi="zenodo.278300", _sandbox=True) + + +def test_no_matching_file_at_doi(): + with pytest.raises(FileNotFoundError, match="sandbox.zenodo"): + get_model( + "file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300", _sandbox=True ) From f9856dbef4794c9b36f4d029003fee2ed1708dc0 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Mon, 7 Jul 2025 14:00:52 -0700 Subject: [PATCH 06/14] cosmetic commit to kick ci --- openff/nagl_models/_dynamic_fetch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 315df66..753603e 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -49,7 +49,7 @@ def get_model( This method will raise an HashComparisonFailedException as soon as a hash mismatch is encountered. So if there's a file with a matching name but a non-matching hash in the local cache, an exception will be raised immediately, even if a file with a matching name that WOULD satisfy the hash check exists in release - metadata or at a provided Zenodo DOI, + metadata or at a provided Zenodo DOI. Parameters ---------- From 57f222e45e754cd187617a2c044a7a14a2d76490 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Mon, 7 Jul 2025 14:05:52 -0700 Subject: [PATCH 07/14] remove dependency on openff-utilities --- openff/nagl_models/_dynamic_fetch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 753603e..636116b 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -4,7 +4,6 @@ import re import pathlib import urllib.request -from openff.utilities.exceptions import OpenFFError import platformdirs from packaging.version import Version @@ -21,11 +20,11 @@ CACHE_DIR = platformdirs.user_cache_path() / "OPENFF_NAGL_MODELS" -class HashComparisonFailedException(OpenFFError): +class HashComparisonFailedException(Exception): """Exception raised when a NAGL file being loaded fails a comparison to a known or user-provided hash.""" -class UnableToParseDOIException(OpenFFError): +class UnableToParseDOIException(Exception): """Exception raised when a Zenodo DOI is unable to be parsed according to the expected pattern.""" From 9522d571dc8f883b4ba08427d08f8b092218c0b7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 08:47:04 -0700 Subject: [PATCH 08/14] remove redundant typing from docstring --- openff/nagl_models/_dynamic_fetch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 636116b..f008417 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -52,25 +52,25 @@ def get_model( Parameters ---------- - filename : str + filename The name of the file to search for. - doi : typing.Optional[str], default=None + doi The Zenodo DOI to use as a backup location for fetching the model file if it's not found in the local cache or in the [release metadata of an openff-nagl-models release](https://github.com/openforcefield/openff-nagl-models/releases) on GitHub. For example: "10.5072/zenodo.278300" - file_hash : typing.Optional[str], default=None + file_hash The sha256 hash of the model file to verify the correct contents. Hash checks are automatically performed on some OpenFF-released NAGL models. But if the model isn't released by OpenFF and this argument is not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedException if unsuccessful. If a user provides a hash value here that disagrees with the known hash for the same file name, the user-provided hash takes precedence. - _sandbox : bool, default=False + _sandbox Whether to connect to sandbox.zenodo.com instead of zenodo.com. Used for testing. Returns ------- - pathlib.Path + str The path to the file if it was found. If the file wasn't found then a FileNotFoundError is rasied. Raises From f560c1fd510ed33d5479a00abc4e02cee4cc078f Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 08:48:07 -0700 Subject: [PATCH 09/14] fix return type --- openff/nagl_models/_dynamic_fetch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index f008417..7535855 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -38,7 +38,7 @@ def get_model( doi: None | str = None, file_hash: None | str = None, _sandbox: bool = False, -) -> pathlib.Path: +) -> str: """ Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is: 1. Try to retrieve the file from the local cache From c6eb3119db8b2526997334274aa3c639ec8f3fb5 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 08:48:28 -0700 Subject: [PATCH 10/14] don't define unnecessary variable check_hash --- openff/nagl_models/_dynamic_fetch.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 7535855..c1235ee 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -91,13 +91,12 @@ def assert_hash_equal(cached_path, expected_hash): cached_path = CACHE_DIR / filename - check_hash = file_hash - if check_hash is None and filename in KNOWN_HASHES: - check_hash = KNOWN_HASHES[filename] + if file_hash is None and filename in KNOWN_HASHES: + file_hash = KNOWN_HASHES[filename] if cached_path.exists(): - if check_hash: - assert_hash_equal(cached_path, check_hash) + if file_hash: + assert_hash_equal(cached_path, file_hash) return cached_path.as_posix() @@ -121,8 +120,8 @@ def assert_hash_equal(cached_path, expected_hash): assert cached_path.exists() assert path_to_file == cached_path.as_posix() - if check_hash: - assert_hash_equal(cached_path, check_hash) + if file_hash: + assert_hash_equal(cached_path, file_hash) return cached_path.as_posix() @@ -151,8 +150,8 @@ def assert_hash_equal(cached_path, expected_hash): assert cached_path.exists() assert path_to_file == cached_path.as_posix() - if check_hash: - assert_hash_equal(cached_path, check_hash) + if file_hash: + assert_hash_equal(cached_path, file_hash) return cached_path.as_posix() From 1d6789979a2ae2431eaf7188f6f4d8843d2bacf4 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 10:47:19 -0700 Subject: [PATCH 11/14] infer sandbox-ness by checking doi prefix --- openff/nagl_models/_dynamic_fetch.py | 14 +++++++------- openff/nagl_models/tests/test_dynamic_fetch.py | 8 +++----- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index c1235ee..99483e4 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -37,7 +37,6 @@ def get_model( filename: str, doi: None | str = None, file_hash: None | str = None, - _sandbox: bool = False, ) -> str: """ Return the path of a model as cached on disk, downloading if necessary. The lookup order of this implementation is: @@ -65,8 +64,6 @@ def get_model( not provided or has a value of `None`, then no hash check is performed. Raises HashComparisonFailedException if unsuccessful. If a user provides a hash value here that disagrees with the known hash for the same file name, the user-provided hash takes precedence. - _sandbox - Whether to connect to sandbox.zenodo.com instead of zenodo.com. Used for testing. Returns ------- @@ -127,14 +124,17 @@ def assert_hash_equal(cached_path, expected_hash): if doi: try: - zenodo_id = re.findall("10.5072/zenodo.([0-9]+)", doi)[0] - except IndexError: + match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi) + if not match: + raise IndexError + prefix, zenodo_id = match.groups() + except (IndexError, AttributeError): raise UnableToParseDOIException( f"Unable to parse Zenodo DOI {doi}. DOI values are expected to look " - f"like '10.5072/zenodo.278300'" + f"like '10.5281/zenodo.278300' (production) or '10.5072/zenodo.278300' (sandbox)" ) - if _sandbox: + if prefix == "5072": file_url = ( f"https://sandbox.zenodo.org/api/records/{zenodo_id}/files/{filename}" ) diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index cacb5a5..98038e1 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -175,12 +175,11 @@ def test_get_model_by_doi_and_hash(hide_cache): "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="127eb0b9512f22546f8b455582bcd85b2521866d32b86d231fee26d4771b1d81", - _sandbox=True, ) def test_get_model_by_doi_no_hash(hide_cache): - get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300", _sandbox=True) + get_model("my_favorite_model.pt", doi="10.5072/zenodo.278300") def test_get_model_hash_comparison_fails(): @@ -189,7 +188,6 @@ def test_get_model_hash_comparison_fails(): "my_favorite_model.pt", doi="10.5072/zenodo.278300", file_hash="wrong_hash", - _sandbox=True, ) @@ -212,11 +210,11 @@ def test_malformed_doi(monkeypatch, hide_cache): ) with pytest.raises(UnableToParseDOIException): - get_model("my_favorite_model.pt", doi="zenodo.278300", _sandbox=True) + get_model("my_favorite_model.pt", doi="zenodo.278300") def test_no_matching_file_at_doi(): with pytest.raises(FileNotFoundError, match="sandbox.zenodo"): get_model( - "file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300", _sandbox=True + "file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300" ) From 4c00163f7a8258c1812d5972a42b0a0cc6451833 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 11:02:30 -0700 Subject: [PATCH 12/14] avoid duplicating downloading and hash checking code --- openff/nagl_models/_dynamic_fetch.py | 54 ++++++++++++++-------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index 99483e4..fd4fc8c 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -76,13 +76,6 @@ def get_model( FileNotFoundError """ - def assert_hash_equal(cached_path, expected_hash): - actual_hash = _get_sha256(cached_path) - if actual_hash != expected_hash: - raise HashComparisonFailedException( - f"NAGL model file hash check failed. Expected hash is " - f"{expected_hash} but actual hash is {actual_hash}" - ) pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) @@ -109,19 +102,10 @@ def assert_hash_equal(cached_path, expected_hash): release = releases[version] for file in release["assets"]: if file["name"] == filename: - path_to_file, _ = urllib.request.urlretrieve( - url=file["browser_download_url"], - filename=cached_path.as_posix(), + return _download_and_verify_file( + file["browser_download_url"], cached_path, file_hash ) - assert cached_path.exists() - assert path_to_file == cached_path.as_posix() - - if file_hash: - assert_hash_equal(cached_path, file_hash) - - return cached_path.as_posix() - if doi: try: match = re.search(r"10\.(5072|5281)/zenodo\.([0-9]+)", doi) @@ -142,24 +126,38 @@ def assert_hash_equal(cached_path, expected_hash): file_url = f"https://zenodo.org/api/records/{zenodo_id}/files/{filename}" try: - path_to_file, _ = urllib.request.urlretrieve( - file_url, filename=cached_path.as_posix() - ) + return _download_and_verify_file(file_url, cached_path, file_hash) except urllib.error.HTTPError: raise FileNotFoundError(f"No file at {file_url}") - assert cached_path.exists() - assert path_to_file == cached_path.as_posix() - - if file_hash: - assert_hash_equal(cached_path, file_hash) - - return cached_path.as_posix() raise FileNotFoundError( f"Could not find asset with name '{filename}' in any release" ) +def assert_hash_equal(cached_path, expected_hash): + actual_hash = _get_sha256(cached_path) + if actual_hash != expected_hash: + raise HashComparisonFailedException( + f"NAGL model file hash check failed. Expected hash is " + f"{expected_hash} but actual hash is {actual_hash}" + ) + +def _download_and_verify_file(url: str, cached_path: pathlib.Path, file_hash: None | str = None) -> str: + """Download a file from URL to cached_path and optionally verify its hash.""" + path_to_file, _ = urllib.request.urlretrieve( + url, filename=cached_path.as_posix() + ) + + assert cached_path.exists() + assert path_to_file == cached_path.as_posix() + + if file_hash: + assert_hash_equal(cached_path, file_hash) + + return cached_path.as_posix() + + def _get_sha256(filename: str) -> str: """Get the SHA256 hash of a file from its path, assuming it's a binary file like a PyTorch model.""" hash = hashlib.sha256() From 4389a51d97b3d63c7003901792b8eb98c2bb5a05 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 11:09:06 -0700 Subject: [PATCH 13/14] clarify where zenodo record came from --- openff/nagl_models/tests/test_dynamic_fetch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 98038e1..154d3f8 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -171,6 +171,8 @@ def test_all_models_loadable(model, monkeypatch): def test_get_model_by_doi_and_hash(hide_cache): + # This test uses a Zenodo sandbox DOI (10.5072 prefix) and the corresponding + # SHA256 hash of the test file uploaded to that sandbox record get_model( "my_favorite_model.pt", doi="10.5072/zenodo.278300", From d5eb021168e68a8423f212e0adf481f0daabdae7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wagner Date: Fri, 11 Jul 2025 11:10:59 -0700 Subject: [PATCH 14/14] black --- openff/nagl_models/_dynamic_fetch.py | 10 +++++----- openff/nagl_models/openff_nagl_models.py | 16 ++++++++-------- openff/nagl_models/tests/test_dynamic_fetch.py | 4 +--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/openff/nagl_models/_dynamic_fetch.py b/openff/nagl_models/_dynamic_fetch.py index fd4fc8c..a8f7616 100644 --- a/openff/nagl_models/_dynamic_fetch.py +++ b/openff/nagl_models/_dynamic_fetch.py @@ -76,7 +76,6 @@ def get_model( FileNotFoundError """ - pathlib.Path(CACHE_DIR).mkdir(exist_ok=True) cached_path = CACHE_DIR / filename @@ -143,11 +142,12 @@ def assert_hash_equal(cached_path, expected_hash): f"{expected_hash} but actual hash is {actual_hash}" ) -def _download_and_verify_file(url: str, cached_path: pathlib.Path, file_hash: None | str = None) -> str: + +def _download_and_verify_file( + url: str, cached_path: pathlib.Path, file_hash: None | str = None +) -> str: """Download a file from URL to cached_path and optionally verify its hash.""" - path_to_file, _ = urllib.request.urlretrieve( - url, filename=cached_path.as_posix() - ) + path_to_file, _ = urllib.request.urlretrieve(url, filename=cached_path.as_posix()) assert cached_path.exists() assert path_to_file == cached_path.as_posix() diff --git a/openff/nagl_models/openff_nagl_models.py b/openff/nagl_models/openff_nagl_models.py index 23f1e68..dd55b84 100644 --- a/openff/nagl_models/openff_nagl_models.py +++ b/openff/nagl_models/openff_nagl_models.py @@ -2,6 +2,7 @@ This module only contains the function that will be the entry point that will be used to find the model files. """ + import importlib.resources import os import pathlib @@ -166,7 +167,8 @@ def list_available_nagl_models() -> list[pathlib.Path]: # look for all .pt files in the cache directory, but only those that are # expected to also be found in release assets cached_paths = [ - cached_file for cached_file in CACHE_DIR.rglob("*.pt") + cached_file + for cached_file in CACHE_DIR.rglob("*.pt") if cached_file.name in KNOWN_HASHES ] @@ -205,12 +207,12 @@ def get_models_by_type( -------- Getting the latest pre-release model for am1bcc:: - + >>> from openff.nagl_models.openff_nagl_models import get_models_by_type >>> get_models_by_type(model_type="am1bcc") [PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.0.1-alpha.1.pt'), PosixPath('/.../openff-nagl-models/openff/nagl_models/models/am1bcc/openff-gnn-am1bcc-0.1.0-rc.1.pt')] - + """ from packaging.version import Version @@ -221,14 +223,12 @@ def get_models_by_type( "If you are using a custom model, " "please manually specify the path to the model file." ) - + model_files = pathlib.Path(base_dir).glob("*.pt") - + # assume everything follows the openff-gnn--.pt format n_name = len(f"openff-gnn-{model_type}-") - versions_to_paths = { - Version(f.stem[n_name:]): f for f in model_files - } + versions_to_paths = {Version(f.stem[n_name:]): f for f in model_files} versions = sorted(versions_to_paths.keys()) if production_only: versions = [v for v in versions if not v.is_prerelease] diff --git a/openff/nagl_models/tests/test_dynamic_fetch.py b/openff/nagl_models/tests/test_dynamic_fetch.py index 154d3f8..149ff08 100644 --- a/openff/nagl_models/tests/test_dynamic_fetch.py +++ b/openff/nagl_models/tests/test_dynamic_fetch.py @@ -217,6 +217,4 @@ def test_malformed_doi(monkeypatch, hide_cache): def test_no_matching_file_at_doi(): with pytest.raises(FileNotFoundError, match="sandbox.zenodo"): - get_model( - "file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300" - ) + get_model("file_that_doesnt_exist.pt", doi="10.5072/zenodo.278300")