diff --git a/fireworks/core/firework.py b/fireworks/core/firework.py index 7bed247ce..f695d4d1f 100644 --- a/fireworks/core/firework.py +++ b/fireworks/core/firework.py @@ -782,6 +782,10 @@ def __init__( created_on (datetime): time of creation updated_on (datetime): time of update fw_states (dict): leave this alone unless you are purposefully creating a Lazy-style WF. + + Raises: + ValueError: when Firework IDs are duplicated or inconsistent, + or links dictionary is invalid """ name = name or "unnamed WF" # prevent None names @@ -894,6 +898,9 @@ def apply_action(self, action: FWAction, fw_id: int) -> list[int]: Returns: list[int]: list of Firework ids that were updated or new. + + Raises: + ValueError: when duplicated Firework IDs are found in additions or detours """ updated_ids = [] @@ -1021,6 +1028,10 @@ def append_wf(self, new_wf, fw_ids, detour=False, pull_spec_mods=False): Returns: list[int]: list of Firework ids that were updated or new. + + Raises: + TypeError: when detour is not boolean + ValueError: when detour or fw_ids inputs are invalid """ updated_ids = [] diff --git a/fireworks/core/launchpad.py b/fireworks/core/launchpad.py index 401976df2..9cf9ce4dc 100644 --- a/fireworks/core/launchpad.py +++ b/fireworks/core/launchpad.py @@ -33,6 +33,7 @@ WFLOCK_EXPIRATION_SECS, MongoClient, ) +from fireworks.utilities.exceptions import FWValueError from fireworks.utilities.fw_serializers import FWSerializable, reconstitute_dates, recursive_dict from fireworks.utilities.fw_utilities import get_fw_logger @@ -312,6 +313,9 @@ def reset(self, password, require_password=True, max_reset_wo_password=25) -> No max_reset_wo_password to minimize risk. max_reset_wo_password (int): A failsafe; when require_password is set to False, FWS will not clear DBs that contain more workflows than this parameter + + Raises: + ValueError: in case of invalid password or failed password override """ m_password = datetime.datetime.now().strftime("%Y-%m-%d") @@ -461,6 +465,9 @@ def get_launch_by_id(self, launch_id): Returns: Launch object + + Raises: + ValueError: in case of invalid launch_id """ m_launch = self.launches.find_one({"launch_id": launch_id}) if m_launch: @@ -476,6 +483,9 @@ def get_fw_dict_by_id(self, fw_id): Returns: dict + + Raises: + ValueError: in case of invalid fw_ids """ fw_dict = self.fireworks.find_one({"fw_id": fw_id}) if not fw_dict: @@ -514,6 +524,9 @@ def get_wf_by_fw_id(self, fw_id): Returns: A Workflow object + + Raises: + ValueError: in case of invalid fw_id """ links_dict = self.workflows.find_one({"nodes": fw_id}) if not links_dict: @@ -536,6 +549,9 @@ def get_wf_by_fw_id_lzyfw(self, fw_id: int) -> Workflow: Returns: A Workflow object + + Raises: + ValueError: in case of invalid fw_id """ links_dict = self.workflows.find_one({"nodes": fw_id}) if not links_dict: @@ -1561,6 +1577,9 @@ def complete_launch(self, launch_id, action=None, state="COMPLETED"): Returns: dict: updated launch + + Raises: + DocumentTooLarge: in some cases when the document size limit is exceeded """ # update the launch data to COMPLETED, set end time, etc m_launch = self.get_launch_by_id(launch_id) @@ -1628,6 +1647,9 @@ def get_new_fw_id(self, quantity=1): Args: quantity (int): optionally ask for many ids, otherwise defaults to 1 this then returns the *first* fw_id in that range + + Raises: + ValueError: if next Firework id cannot be found """ try: return self.fw_id_assigner.find_one_and_update({}, {"$inc": {"next_fw_id": quantity}})["next_fw_id"] @@ -1799,6 +1821,10 @@ def _refresh_wf(self, fw_id) -> None: Args: fw_id (int): the parent fw_id - children will be refreshed + + Raises: + RuntimeError: in case of an error when refreshing the workflow + different from LockedWorkflowError """ # TODO: time how long it took to refresh the WF! # TODO: need a try-except here, high probability of failure if incorrect action supplied @@ -1827,6 +1853,9 @@ def _update_wf(self, wf, updated_ids) -> None: Args: wf (Workflow) updated_ids ([int]): list of firework ids + + Raises: + FWValueError: when the query finds no matching workflow """ updated_fws = [wf.id_fw[fid] for fid in updated_ids] old_new = self._upsert_fws(updated_fws) @@ -1841,7 +1870,7 @@ def _update_wf(self, wf, updated_ids) -> None: assert query_node is not None if not self.workflows.find_one({"nodes": query_node}): - raise ValueError(f"BAD QUERY_NODE! {query_node}") + raise FWValueError(f"BAD QUERY_NODE! {query_node}") # redo the links and fw_states wf = wf.to_db_dict() wf["locked"] = True # preserve the lock! diff --git a/fireworks/core/rocket.py b/fireworks/core/rocket.py index 504ef68a7..0fbe56cc6 100644 --- a/fireworks/core/rocket.py +++ b/fireworks/core/rocket.py @@ -126,6 +126,9 @@ def run(self, pdb_on_exception: bool = False, err_file: IO = None) -> bool: Returns: bool: True if the rocket ran successfully, False is if it failed or no job in the DB was ready to run. + + Raises: + OSError: when creation of launch directory fails """ all_stored_data = {} # combined stored data for *all* the Tasks all_update_spec = {} # combined update_spec for *all* the Tasks diff --git a/fireworks/fw_config.py b/fireworks/fw_config.py index 81ebbdba0..ebfcc815b 100644 --- a/fireworks/fw_config.py +++ b/fireworks/fw_config.py @@ -9,6 +9,7 @@ import pymongo from monty.design_patterns import singleton from monty.serialization import dumpfn, loadfn +from fireworks.utilities.exceptions import FWConfigurationError __author__ = "Anubhav Jain" __copyright__ = "Copyright 2012, The Materials Project" @@ -137,13 +138,15 @@ def override_user_settings() -> None: if os.path.exists(config_paths[0]): overrides = loadfn(config_paths[0]) + if not isinstance(overrides, dict): + raise FWConfigurationError(f"Invalid FW_config file, type must be dict but is {type(overrides)}") for key, v in overrides.items(): if key == "ADD_USER_PACKAGES": USER_PACKAGES.extend(v) elif key == "ECHO_TEST": print(v) elif key not in globals(): - raise ValueError(f"Invalid FW_config file has unknown parameter: {key}") + raise FWConfigurationError(f"Invalid FW_config file has unknown parameter: {key}") else: globals()[key] = v diff --git a/fireworks/queue/queue_launcher.py b/fireworks/queue/queue_launcher.py index 621a1542e..190b9e1b4 100644 --- a/fireworks/queue/queue_launcher.py +++ b/fireworks/queue/queue_launcher.py @@ -62,6 +62,10 @@ def launch_rocket_to_queue( fill_mode (bool): whether to submit jobs even when there is nothing to run (only in non-reservation mode) fw_id (int): specific fw_id to reserve (reservation mode only) + + Raises: + ValueError: in case of invalid or inconsistent parameters + RuntimeError: when job script submission fails """ fworker = fworker or FWorker() launcher_dir = os.path.abspath(launcher_dir) @@ -204,6 +208,9 @@ def rapidfire( timeout (int): Number of seconds after which to stop the rapidfire process fill_mode (bool): Whether to submit jobs even when there is nothing to run (only in non-reservation mode) + Raises: + ValueError: when launch_dir does not exist or block_dir has invalid name + RuntimeError: when launch fails """ sleep_time = sleep_time or RAPIDFIRE_SLEEP_SECS launch_dir = os.path.abspath(launch_dir) @@ -315,6 +322,9 @@ def _get_number_of_jobs_in_queue(qadapter: "QueueAdapterBase", njobs_queue: int, Return: (int) + + Raises: + RuntimeError: if determination of number of jobs fails """ RETRY_INTERVAL = 30 # initial retry in 30 sec upon failure diff --git a/fireworks/scripts/lpad_run.py b/fireworks/scripts/lpad_run.py index f5144cce2..448716cbd 100644 --- a/fireworks/scripts/lpad_run.py +++ b/fireworks/scripts/lpad_run.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any from pymongo import ASCENDING, DESCENDING +from pymongo.errors import PyMongoError from ruamel.yaml import YAML from fireworks import FW_INSTALL_DIR @@ -124,7 +125,7 @@ def get_lp(args: Namespace) -> LaunchPad: lp.connection.admin.command("ping") return lp - except Exception: + except PyMongoError: err_message = ( f"FireWorks was not able to connect to MongoDB at {lp.host}:{lp.port}. Is the server running? " f"The database file specified was {args.launchpad_file}." diff --git a/fireworks/tests/test_fw_config.py b/fireworks/tests/test_fw_config.py index 6602a6584..efcd8a0bd 100644 --- a/fireworks/tests/test_fw_config.py +++ b/fireworks/tests/test_fw_config.py @@ -4,12 +4,48 @@ __email__ = "ongsp@ucsd.edu" __date__ = "2/3/14" +import os import unittest +from tempfile import mkdtemp -from fireworks.fw_config import config_to_dict +import pytest + +from fireworks.fw_config import config_to_dict, override_user_settings +from fireworks.utilities.exceptions import FWConfigurationError class ConfigTest(unittest.TestCase): def test_config(self) -> None: d = config_to_dict() assert "NEGATIVE_FWID_CTR" not in d + + +class FWConfigTest(unittest.TestCase): + """Tests for the fw_config module.""" + + def setUp(self): + self.init_dir = os.getcwd() + self.fw_config_dir = mkdtemp() + os.chdir(self.fw_config_dir) + self.fw_config = os.path.join(self.fw_config_dir, "FW_config.yaml") + with open(self.fw_config, "w", encoding="utf-8"): + pass + + def tearDown(self): + os.chdir(self.init_dir) + os.unlink(self.fw_config) + os.rmdir(self.fw_config_dir) + + def test_override_user_settings_empty_yaml(self) -> None: + """Test with empty fw_config file.""" + msg = "Invalid FW_config file, type must be dict but is " + with pytest.raises(FWConfigurationError, match=msg): + override_user_settings() + + def test_override_user_settings_invalid_key(self) -> None: + """Test fw_config file with invalid key.""" + with open(self.fw_config, "a", encoding="utf-8") as fh: + fh.write("blah: true") + msg = "Invalid FW_config file has unknown parameter: blah" + with pytest.raises(FWConfigurationError, match=msg): + override_user_settings() diff --git a/fireworks/utilities/exceptions.py b/fireworks/utilities/exceptions.py new file mode 100644 index 000000000..1c13ac2e4 --- /dev/null +++ b/fireworks/utilities/exceptions.py @@ -0,0 +1,21 @@ +"""FireWorks exceptions.""" + + +class FWError(Exception): + """Base exception for all other FireWorks exceptions.""" + + +class FWConfigurationError(FWError): + """Raise for errors related to fw_config.""" + + +class FWSerializationError(FWError): + """Raise for errors related to serialization/deserialization.""" + + +class FWFormatError(FWError): + """Raise for errors related to file format.""" + + +class FWValueError(FWError, ValueError): + """FireWorks specialization of ValueError.""" diff --git a/fireworks/utilities/fw_serializers.py b/fireworks/utilities/fw_serializers.py index 6a639c1b2..ae0726e82 100644 --- a/fireworks/utilities/fw_serializers.py +++ b/fireworks/utilities/fw_serializers.py @@ -50,6 +50,7 @@ YAML_STYLE, ) from fireworks.utilities.fw_utilities import get_fw_logger +from fireworks.utilities.exceptions import FWSerializationError, FWFormatError __author__ = "Anubhav Jain" __copyright__ = "Copyright 2012, The Materials Project" @@ -73,7 +74,7 @@ def DATETIME_HANDLER(obj): import numpy as np NUMPY_INSTALLED = True -except Exception: +except ModuleNotFoundError: NUMPY_INSTALLED = False if JSON_SCHEMA_VALIDATE: @@ -234,6 +235,9 @@ def to_format(self, f_format="json", **kwargs): Args: f_format (str): the format to output to (default json) **kwargs: additional keyword arguments passed to the serializer + + Raises: + FWFormatError: when f_format is not supported """ if f_format == "json": return json.dumps(self.to_dict(), default=DATETIME_HANDLER, **kwargs) @@ -244,7 +248,7 @@ def to_format(self, f_format="json", **kwargs): strm = StringIO() yaml.dump(self.to_dict(), strm) return strm.getvalue() - raise ValueError(f"Unsupported format {f_format}") + raise FWFormatError(f"Unsupported format {f_format}") @classmethod def from_format(cls, f_str, f_format="json"): @@ -262,7 +266,9 @@ def from_format(cls, f_str, f_format="json"): elif f_format == "yaml": dct = YAML(typ="safe", pure=True).load(f_str) else: - raise ValueError(f"Unsupported format {f_format}") + raise FWFormatError(f"Unsupported format {f_format}") + if not isinstance(dct, dict): + raise FWSerializationError(f"Serialized object must be a dict but is {type(dct)}") if JSON_SCHEMA_VALIDATE and cls.__name__ in JSON_SCHEMA_VALIDATE_LIST: fireworks_schema.validate(dct, cls.__name__) return cls.from_dict(reconstitute_dates(dct)) @@ -279,7 +285,7 @@ def to_file(self, filename, f_format=None, **kwargs) -> None: if f_format is None: f_format = filename.split(".")[-1] if f_format not in ("json", "yaml"): - raise ValueError(f"Unsupported format {f_format}") + raise FWFormatError(f"Unsupported format {f_format}") with open(filename, "w", **ENCODING_PARAMS) as f_out: if f_format == "json": json.dump(dct, f_out, default=DATETIME_HANDLER, **kwargs) @@ -338,6 +344,9 @@ def load_object(obj_dict): Args: obj_dict (dict): the dict representation of the class + + Raises: + FWSerializationError: in case none or multiple classes match _fw_name """ # override the name in the obj_dict if there's an entry in FW_NAME_UPDATES fw_name = FW_NAME_UPDATES.get(obj_dict["_fw_name"], obj_dict["_fw_name"]) @@ -378,9 +387,10 @@ def load_object(obj_dict): SAVED_FW_MODULES[fw_name] = found_objects[0][1] return found_objects[0][0] if len(found_objects) > 0: - raise ValueError(f"load_object() found multiple objects with cls._fw_name {fw_name} -- {found_objects}") + msg = f"load_object() found multiple objects with cls._fw_name {fw_name} -- {found_objects}" + raise FWSerializationError(msg) - raise ValueError(f"load_object() could not find a class with cls._fw_name {fw_name}") + raise FWSerializationError(f"load_object() could not find a class with cls._fw_name {fw_name}") def load_object_from_file(filename, f_format=None): @@ -391,6 +401,10 @@ def load_object_from_file(filename, f_format=None): filename (str): the filename to load an object from f_format (str): the serialization format (default is auto-detect based on filename extension) + + Raises: + FWFormatError: when file with filename has unsupported format or f_format is invalid + FWSerializationError: when the data read is not a dictionary """ if f_format is None: f_format = filename.split(".")[-1] @@ -401,8 +415,10 @@ def load_object_from_file(filename, f_format=None): elif f_format == "yaml": dct = YAML(typ="safe", pure=True).load(f.read()) else: - raise ValueError(f"Unknown file format {f_format} cannot be loaded!") + raise FWFormatError(f"Unknown file format {f_format} cannot be loaded!") + if not isinstance(dct, dict): + raise FWSerializationError(f"Serialized object must be a dict but is {type(dct)}") classname = FW_NAME_UPDATES.get(dct["_fw_name"], dct["_fw_name"]) if JSON_SCHEMA_VALIDATE and classname in JSON_SCHEMA_VALIDATE_LIST: fireworks_schema.validate(dct, classname) @@ -438,12 +454,12 @@ def reconstitute_dates(obj_dict): if isinstance(obj_dict, str): for method, args in [ - (datetime.datetime.fromisoformat,tuple()), + (datetime.datetime.fromisoformat, tuple()), (datetime.datetime.strptime, ("%Y-%m-%dT%H:%M:%S.%f",)), (datetime.datetime.strptime, ("%Y-%m-%dT%H:%M:%S", )), ]: try: - return method(obj_dict,*args) + return method(obj_dict, *args) except Exception: pass return obj_dict @@ -453,7 +469,7 @@ def get_default_serialization(cls): """Get the default serialization string for a class.""" root_mod = cls.__module__.split(".")[0] if root_mod == "__main__": - raise ValueError( + raise FWSerializationError( "Cannot get default serialization; try " "instantiating your object from a different module " "from which it is defined rather than defining your " diff --git a/fireworks/utilities/tests/test_fw_serializers.py b/fireworks/utilities/tests/test_fw_serializers.py index c06395208..c5800e111 100644 --- a/fireworks/utilities/tests/test_fw_serializers.py +++ b/fireworks/utilities/tests/test_fw_serializers.py @@ -4,12 +4,15 @@ import json import os import unittest +from tempfile import mkdtemp from typing import Any import numpy as np +import pytest from fireworks.user_objects.firetasks.unittest_tasks import ExportTestSerializer, UnitTestSerializer -from fireworks.utilities.fw_serializers import FWSerializable, load_object, recursive_dict +from fireworks.utilities.exceptions import FWFormatError, FWSerializationError +from fireworks.utilities.fw_serializers import FWSerializable, load_object, load_object_from_file, recursive_dict from fireworks.utilities.fw_utilities import explicit_serialize __author__ = "Anubhav Jain" @@ -147,3 +150,70 @@ def setUp(self) -> None: def test_explicit_serialization(self) -> None: assert load_object(self.s_dict) == self.s_obj + + +class FWSerializationErrorTest(unittest.TestCase): + """Test FWSerializationError exception.""" + + def setUp(self): + self.init_dir = os.getcwd() + self.lpad_dir = mkdtemp() + os.chdir(self.lpad_dir) + self.lpad_file = os.path.join(self.lpad_dir, "launchpad.yaml") + with open(self.lpad_file, "w", encoding="utf-8"): + pass + self.msg = "Serialized object must be a dict but is " + + def tearDown(self): + os.chdir(self.init_dir) + os.unlink(self.lpad_file) + os.rmdir(self.lpad_dir) + + def test_load_object_from_file_empty_file(self) -> None: + """Test load_object_from_file with empty file.""" + with pytest.raises(FWSerializationError, match=self.msg): + load_object_from_file(self.lpad_file) + + def test_explicit_serializer_from_file_empty_file(self) -> None: + """Test ExplicitTestSerializer with empty file.""" + with pytest.raises(FWSerializationError, match=self.msg): + ExplicitTestSerializer.from_file(self.lpad_file) + + +class FWFormatErrorTest(unittest.TestCase): + """Test FWFormatError exception.""" + + def setUp(self): + self.init_dir = os.getcwd() + self.lpad_dir = mkdtemp() + os.chdir(self.lpad_dir) + self.lpad_file = os.path.join(self.lpad_dir, "launchpad.txt") + with open(self.lpad_file, "w", encoding="utf-8"): + pass + self.msg1 = "Unsupported format txt" + self.msg2 = "Unknown file format txt cannot be loaded!" + + def tearDown(self): + os.chdir(self.init_dir) + os.unlink(self.lpad_file) + os.rmdir(self.lpad_dir) + + def test_explicit_serializer_from_file(self) -> None: + """Test ExplicitTestSerializer from txt file.""" + with pytest.raises(FWFormatError, match=self.msg1): + ExplicitTestSerializer.from_file(self.lpad_file) + + def test_explicit_serializer_to_file(self) -> None: + """Test ExplicitTestSerializer to txt file.""" + with pytest.raises(FWFormatError, match=self.msg1): + ExplicitTestSerializer(a=1).to_file(self.lpad_file) + + def test_explicit_serializer_to_format(self) -> None: + """Test ExplicitTestSerializer to txt file.""" + with pytest.raises(FWFormatError, match=self.msg1): + ExplicitTestSerializer(a=1).to_format("txt") + + def test_load_object_from_file(self) -> None: + """Test load_object_from_file with txt file.""" + with pytest.raises(FWFormatError, match=self.msg2): + load_object_from_file(self.lpad_file)