diff --git a/emmet-builders/emmet/builders/abinit/__init__.py b/emmet-builders/emmet/builders/abinit/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/emmet-builders/emmet/builders/abinit/phonon.py b/emmet-builders/emmet/builders/abinit/phonon.py deleted file mode 100644 index c14ce8abe6..0000000000 --- a/emmet-builders/emmet/builders/abinit/phonon.py +++ /dev/null @@ -1,862 +0,0 @@ -import os -import tempfile -import warnings -from math import ceil -from typing import TYPE_CHECKING - -import numpy as np -from abipy.abio.inputs import AnaddbInput -from abipy.core.abinit_units import eV_to_THz -from abipy.dfpt.anaddbnc import AnaddbNcFile -from abipy.dfpt.ddb import AnaddbError, DdbFile, DielectricTensorGenerator -from abipy.dfpt.phonons import PhononBands -from abipy.flowtk.tasks import AnaddbTask, TaskManager -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure -from pymatgen.io.abinit.abiobjects import KSampling -from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine -from pymatgen.phonon.dos import CompletePhononDos -from pymatgen.phonon.ir_spectra import IRDielectricTensor -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatgen.symmetry.bandstructure import HighSymmKpath - -from emmet.builders.settings import EmmetBuildSettings -from emmet.core.phonon import ( - AbinitPhonon, - Ddb, - PhononBS, - PhononDOS, - PhononWarnings, - PhononWebsiteBS, - ThermalDisplacement, - ThermodynamicProperties, - VibrationalEnergy, -) -from emmet.core.polar import BornEffectiveCharges, DielectricDoc, IRDielectric -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - -SETTINGS = EmmetBuildSettings() - - -warnings.warn( - f"The current version of {__name__}.PhononBuilder will be deprecated in version 0.87.0. 
" - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class PhononBuilder(Builder): - def __init__( - self, - phonon_materials: Store, - ddb_source: Store, - phonon: Store, - phonon_bs: Store, - phonon_dos: Store, - ddb_files: Store, - th_disp: Store, - phonon_website: Store, - query: dict | None = None, - manager: TaskManager | None = None, - symprec: float = SETTINGS.SYMPREC, - angle_tolerance: float = SETTINGS.ANGLE_TOL, - chunk_size=100, - **kwargs, - ): - """ - Creates a set of collections for materials generating different kind of data - from the phonon calculations. - The builder requires the execution of the anaddb tool available in abinit. - The parts that may contain large amount of data are split from the main - document and store in separated collections. Notice that in these cases - the size of a single document may be above the 16MB limit allowed by the - standard MongoDB document. - - Args: - phonon_materials (Store): source Store of phonon materials documents - containing abinit_input and abinit_output. - ddb_source (Store): source Store of ddb files. Matching the data in the materials Store. - phonon (Store): target Store of the phonon properties - phonon_bs (Store): target Store for the phonon band structure. The document may - exceed the 16MB limit of a mongodb collection. - phonon_dos (Store): target Store for the phonon DOS. The document may - exceed the 16MB limit of a mongodb collection. - ddb_files (Store): target Store of the DDB files. The document may - exceed the 16MB limit of a mongodb collection. - th_disp (Store): target Store of the data related to the generalized phonon DOS - with the mean square displacement tensor. The document may exceed the 16MB - limit of a mongodb collection. - phonon_website (Store): target Store for the phonon band structure in the phononwebsite - format. 
The document may exceed the 16MB limit of a mongodb collection. - query (dict): dictionary to limit materials to be analyzed - manager (TaskManager): an instance of the abipy TaskManager. If None it - will be generated from user configuration. - symprec (float): tolerance for symmetry finding when determining the - band structure path. - angle_tolerance (float): angle tolerance for symmetry finding when - determining the band structure path. - """ - - self.phonon_materials = phonon_materials - self.phonon = phonon - self.ddb_source = ddb_source - self.phonon_bs = phonon_bs - self.phonon_dos = phonon_dos - self.ddb_files = ddb_files - self.th_disp = th_disp - self.phonon_website = phonon_website - self.query = query or {} - self.symprec = symprec - self.angle_tolerance = angle_tolerance - self.chunk_size = chunk_size - - if manager is None: - self.manager = TaskManager.from_user_config() - else: - self.manager = manager - - super().__init__( - sources=[phonon_materials, ddb_source], - targets=[phonon, phonon_bs, phonon_dos, ddb_files, th_disp, phonon_website], - chunk_size=chunk_size, - **kwargs, - ) - - def prechunk(self, number_splits: int): # pragma: no cover - """ - Gets all materials that need phonons - - Returns: - generator of materials to extract phonon properties - """ - - # All relevant materials that have been updated since phonon props were last calculated - q = dict(self.query) - - mats = self.phonon.newer_in(self.phonon_materials, exhaustive=True, criteria=q) - - N = ceil(len(mats) / number_splits) - - for mpid_chunk in grouper(mats, N): - yield {"query": {self.phonon_materials.key: {"$in": list(mpid_chunk)}}} - - def get_items(self) -> Iterator[dict]: - """ - Gets all materials that need phonons - - Returns: - generator of materials to extract phonon properties - """ - - self.logger.info("Phonon Builder Started") - - self.logger.info("Setting indexes") - self.ensure_indexes() - - # All relevant materials that have been updated since phonon props were 
last calculated - q = dict(self.query) - - mats = self.phonon.newer_in(self.phonon_materials, exhaustive=True, criteria=q) - self.logger.info("Found {} new materials for phonon data".format(len(mats))) - - # list of properties queried from the results DB - # basic information - projection = { - "mp_id": 1, - "spacegroup.number": 1, - "abinit_input": 1, # input data - "abinit_output.ddb_id": 1, # file ids to be fetched - } - - for m in mats: - item = self.phonon_materials.query_one( - properties=projection, criteria={self.phonon_materials.key: m} - ) - - # Read the DDB file and pass as an object. Do not write here since in case of parallel - # execution each worker will write its own file. - ddb_data = self.ddb_source.query_one( - criteria={"_id": item["abinit_output"]["ddb_id"]} - ) - if not ddb_data: - self.logger.warning( - f"DDB file not found for file id {item['abinit_output']['ddb_id']}" - ) - continue - - try: - item["ddb_str"] = ddb_data["data"].decode("utf-8") - except Exception: - self.logger.warning( - f"could not extract DDB for file id {item['abinit_output']['ddb_id']}" - ) - continue - - yield item - - def process_item(self, item: dict) -> dict | None: - """ - Generates the full phonon document from an item - - Args: - item (dict): a dict extracted from the phonon calculations results. - - Returns: - dict: a dict with the set of phonon data to be saved in the stores. 
- """ - self.logger.debug("Processing phonon item for {}".format(item["mp_id"])) - - try: - structure = Structure.from_dict(item["abinit_input"]["structure"]) - - abinit_input_vars = self.abinit_input_vars(item) - phonon_properties = self.get_phonon_properties(item) - sr_break = self.get_sum_rule_breakings(item) - ph_warnings = get_warnings( - sr_break["asr"], sr_break["cnsr"], phonon_properties["ph_bs"] - ) - if PhononWarnings.NEG_FREQ not in ph_warnings: - thermodynamic, vibrational_energy = get_thermodynamic_properties( - phonon_properties["ph_dos"] - ) - else: - thermodynamic, vibrational_energy = None, None - - becs = None - if phonon_properties["becs"] is not None: - becs = BornEffectiveCharges( - symmetrized_value=phonon_properties["becs"], - value=sr_break["becs_nosymm"], - cnsr_break=sr_break["cnsr"], - ) - - ap = AbinitPhonon.from_structure( - structure=structure, - meta_structure=structure, - include_structure=True, - material_id=item["mp_id"], - cnsr_break=sr_break["cnsr"], - asr_break=sr_break["asr"], - warnings=ph_warnings, - dielectric=phonon_properties["dielectric"], - becs=becs, - ir_spectra=phonon_properties["ir_spectra"], - thermodynamic=thermodynamic, - vibrational_energy=vibrational_energy, - abinit_input_vars=abinit_input_vars, - ) - - phbs = PhononBS( - identifier=item["mp_id"], - **phonon_properties["ph_bs"].as_dict(), - ) - - phws = PhononWebsiteBS( # type: ignore[call-arg] - material_id=item["mp_id"], - phononwebsite=phonon_properties["ph_bs"].as_phononwebsite(), - ) - - phdos = PhononDOS( - identifier=item["mp_id"], - **phonon_properties["ph_dos"].as_dict(), - ) - - ddb = Ddb(material_id=item["mp_id"], ddb=item["ddb_str"]) # type: ignore[call-arg] - - th_disp = ThermalDisplacement( # type: ignore[call-arg] - material_id=item["mp_id"], - structure=structure, - nsites=len(structure), - nomega=phonon_properties["th_disp"]["nomega"], - ntemp=phonon_properties["th_disp"]["ntemp"], - temperatures=phonon_properties["th_disp"]["tmesh"].tolist(), 
- frequencies=phonon_properties["th_disp"]["wmesh"].tolist(), - gdos_aijw=phonon_properties["th_disp"]["gdos_aijw"].tolist(), - amu=phonon_properties["th_disp"]["amu_symbol"], - ucif_t=phonon_properties["th_disp"]["ucif_t"].tolist(), - ucif_string_t300k=phonon_properties["th_disp"]["ucif_string_t300k"], - ) - - self.logger.debug("Item generated for {}".format(item["mp_id"])) - - d = dict( - abiph=jsanitize(ap.model_dump(), allow_bson=True), - phbs=jsanitize(phbs.model_dump(), allow_bson=True), - phws=jsanitize(phws.model_dump(), allow_bson=True), - phdos=jsanitize(phdos.model_dump(), allow_bson=True), - ddb=jsanitize(ddb.model_dump(), allow_bson=True), - th_disp=jsanitize(th_disp.model_dump(), allow_bson=True), - ) - - return d - except Exception as error: - self.logger.warning( - "Error generating the phonon properties for {}: {}".format( - item["mp_id"], error - ) - ) - return None - - def get_phonon_properties(self, item: dict) -> dict: - """ - Extracts the phonon properties from the item - """ - - # the temp dir should still exist when using the objects as some readings are done lazily - with tempfile.TemporaryDirectory() as workdir: - structure = Structure.from_dict(item["abinit_input"]["structure"]) - - self.logger.debug("Running anaddb in {}".format(workdir)) - - ddb_path = os.path.join(workdir, "{}_DDB".format(item["mp_id"])) - with open(ddb_path, "wt") as ddb_file: - ddb_file.write(item["ddb_str"]) - - ddb = DdbFile.from_string(item["ddb_str"]) - has_bec = ddb.has_bec_terms() - has_epsinf = ddb.has_epsinf_terms() - - anaddb_input, labels_list = self.get_properties_anaddb_input( - item, - bs=True, - dos="tetra", - lo_to_splitting=has_bec, - use_dieflag=has_epsinf, - ) - task = self.run_anaddb( - ddb_path=ddb_path, anaddb_input=anaddb_input, workdir=workdir - ) - - with task.open_phbst() as phbst_file, AnaddbNcFile( - task.outpath_from_ext("anaddb.nc") - ) as ananc_file: - # phbst - phbands = phbst_file.phbands - if has_bec: - 
phbands.read_non_anal_from_file(phbst_file.filepath) - symm_line_bands = self.get_pmg_bs(phbands, labels_list) # type: ignore - - # ananc - if has_bec and ananc_file.becs is not None: - becs = ananc_file.becs.values.tolist() - else: - becs = None - if has_epsinf and ananc_file.epsinf is not None: - e_electronic = ananc_file.epsinf.tolist() - else: - e_electronic = None - e_total = ( - ananc_file.eps0.tolist() if ananc_file.eps0 is not None else None - ) - if e_electronic and e_total: - e_ionic = (ananc_file.eps0 - ananc_file.epsinf).tolist() - dielectric = DielectricDoc.from_ionic_and_electronic( - ionic=e_ionic, - electronic=e_electronic, - material_id=item["mp_id"], - structure=structure, - deprecated=False, - ) - else: - dielectric = None - - # both - if ( - e_electronic - and e_total - and ananc_file.oscillator_strength is not None - ): - die_gen = DielectricTensorGenerator.from_objects( - phbands, ananc_file - ) - - ir_tensor = IRDielectricTensor( - die_gen.oscillator_strength, - die_gen.phfreqs, - die_gen.epsinf, - die_gen.structure, - ).as_dict() - ir_spectra = IRDielectric(ir_dielectric_tensor=ir_tensor) - else: - ir_spectra = None - - dos_method = "tetrahedron" - with task.open_phdos() as phdos_file: - complete_dos = phdos_file.to_pymatgen() - msqd_dos = phdos_file.msqd_dos - - # if the integrated dos is not close enough to the expected value (3*N_sites) rerun the DOS using - # gaussian integration - integrated_dos = phdos_file.phdos.integral()[-1][1] - nmodes = 3 * len(phdos_file.structure) - - if np.abs(integrated_dos - nmodes) / nmodes > 0.01: - self.logger.warning( - "Integrated DOS {} instead of {} for {}. 
Recalculating with gaussian".format( - integrated_dos, nmodes, item["mp_id"] - ) - ) - with tempfile.TemporaryDirectory() as workdir_dos: - anaddb_input_dos, _ = self.get_properties_anaddb_input( - item, - bs=False, - dos="gauss", - lo_to_splitting=has_bec, - use_dieflag=has_epsinf, - ) - task_dos = self.run_anaddb( - ddb_path=ddb_path, - anaddb_input=anaddb_input_dos, - workdir=workdir_dos, - ) - with task_dos.open_phdos() as phdos_file: - complete_dos = phdos_file.to_pymatgen() - msqd_dos = phdos_file.msqd_dos - dos_method = "gaussian" - - data = { - "ph_dos": complete_dos, - "ph_dos_method": dos_method, - "ph_bs": symm_line_bands, - "becs": becs, - "ir_spectra": ir_spectra, - "dielectric": dielectric, - "th_disp": msqd_dos.get_json_doc(tstart=0, tstop=800, num=161), - } - - return data - - def get_sum_rule_breakings(self, item: dict) -> dict: - """ - Extracts the breaking of the acoustic and charge neutrality sum rules. - Runs anaddb to get the values. - """ - structure = Structure.from_dict(item["abinit_input"]["structure"]) - anaddb_input = AnaddbInput.modes_at_qpoint( - structure, [0, 0, 0], asr=0, chneut=0 - ) - - with tempfile.TemporaryDirectory() as workdir: - ddb_path = os.path.join(workdir, "{}_DDB".format(item["mp_id"])) - with open(ddb_path, "wt") as ddb_file: - ddb_file.write(item["ddb_str"]) - - ddb = DdbFile.from_string(item["ddb_str"]) - has_bec = ddb.has_bec_terms() - - task = self.run_anaddb(ddb_path, anaddb_input, workdir) - - if has_bec: - with AnaddbNcFile(task.outpath_from_ext("anaddb.nc")) as ananc_file: - becs = ananc_file.becs - becs_val = becs.values.tolist() if becs else None - cnsr = np.max(np.abs(becs.sumrule)) if becs else None - else: - becs_val = None - cnsr = None - - with task.open_phbst() as phbst_file: - phbands = phbst_file.phbands - - # If the ASR breaking could not be identified. set it to None to signal the - # missing information. This may trigger a warning. 
- try: - asr_breaking = phbands.asr_breaking( - units="cm-1", threshold=0.9, raise_on_no_indices=True - ) - asr = asr_breaking.absmax_break - except RuntimeError as e: - self.logger.warning( - "Could not find the ASR breaking for {}. Error: {}".format( - item["mp_id"], e - ) - ) - asr = None - - breakings = {"cnsr": cnsr, "asr": asr, "becs_nosymm": becs_val} - - return breakings - - def run_anaddb( - self, ddb_path: str, anaddb_input: AnaddbInput, workdir: str - ) -> AnaddbTask: - """ - Runs anaddb. Raise AnaddbError if the calculation couldn't complete - - Args: - ddb_path (str): path to the DDB file - anaddb_input (AnaddbInput): the input for anaddb - workdir (str): the directory where the calculation is run - Returns: - An abipy AnaddbTask instance. - """ - - task = AnaddbTask.temp_shell_task( - anaddb_input, ddb_node=ddb_path, workdir=workdir, manager=self.manager - ) - - # Run the task here. - self.logger.debug("Start anaddb for {}".format(ddb_path)) - task.start_and_wait(autoparal=False) - self.logger.debug("Finished anaddb for {}".format(ddb_path)) - - report = task.get_event_report() - if not report.run_completed: - raise AnaddbError(task=task, report=report) - - self.logger.debug("anaddb succesful for {}".format(ddb_path)) - - return task - - def get_properties_anaddb_input( - self, - item: dict, - bs: bool = True, - dos: str = "tetra", - lo_to_splitting: bool = True, - use_dieflag: bool = True, - ) -> tuple[AnaddbInput, list | None]: - """ - creates the AnaddbInput object to calculate the phonon properties. - It also returns the list of qpoints labels for generating the PhononBandStructureSymmLine. 
- - Args: - item: the item to process - bs (bool): if True the phonon band structure will be calculated - dos (str): if 'tetra' the DOS will be calculated with the tetrahedron method, - if 'gauss' with gaussian smearing, if None the DOS will not be calculated - lo_to_splitting (bool): contributions from the LO-TO splitting for the phonon - BS will be calculated. - use_dieflag (bool): the dielectric tensor will be calculated. - """ - - ngqpt = item["abinit_input"]["ngqpt"] - q1shft = [(0, 0, 0)] - - structure = Structure.from_dict(item["abinit_input"]["structure"]) - - # use all the corrections - dipdip = 1 - asr = 2 - chneut = 1 - - inp = AnaddbInput(structure, comment="ANADB input for phonon bands and DOS") - - inp.set_vars( - ifcflag=1, - ngqpt=np.array(ngqpt), - q1shft=q1shft, - nqshft=len(q1shft), - asr=asr, - chneut=chneut, - dipdip=dipdip, - ) - - # Parameters for the dos. - if dos == "tetra": - # Use tetrahedra with dense dosdeltae (required to get accurate value of the integral) - prtdos = 2 - dosdeltae = 9e-07 # Ha = 2 cm^-1 - ng2qppa = 200000 - ng2qpt = KSampling.automatic_density(structure, kppa=ng2qppa).kpts[0] - inp.set_vars(prtdos=prtdos, dosdeltae=dosdeltae, ng2qpt=ng2qpt) - elif dos == "gauss": - # Use gauss with denser grid and a smearing - prtdos = 1 - dosdeltae = 4.5e-06 # Ha = 10 cm^-1 - ng2qppa = 500000 - dossmear = 1.82e-5 # Ha = 4 cm^-1 - ng2qpt = KSampling.automatic_density(structure, kppa=ng2qppa).kpts[0] - inp.set_vars( - prtdos=prtdos, dosdeltae=dosdeltae, dossmear=dossmear, ng2qpt=ng2qpt - ) - elif dos is not None: - raise ValueError("Unsupported value of dos.") - - # Parameters for the BS - labels_list = None - if bs: - spga = SpacegroupAnalyzer( - structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance - ) - - spgn = spga.get_space_group_number() - if spgn != item["spacegroup"]["number"]: - raise RuntimeError( - "Parsed specegroup number {} does not match " - "calculation spacegroup {}".format( - spgn, 
item["spacegroup"]["number"] - ) - ) - - hs = HighSymmKpath( - structure, symprec=self.symprec, angle_tolerance=self.angle_tolerance - ) - - qpts, labels_list = hs.get_kpoints( - line_density=18, coords_are_cartesian=False - ) - - n_qpoints = len(qpts) - qph1l = np.zeros((n_qpoints, 4)) - - qph1l[:, :-1] = qpts - qph1l[:, -1] = 1 - - inp["qph1l"] = qph1l.tolist() - inp["nph1l"] = n_qpoints - - if lo_to_splitting: - kpath = hs.kpath - directions = [] # type: list - for qptbounds in kpath["path"]: - for i, qpt in enumerate(qptbounds): - if np.array_equal(kpath["kpoints"][qpt], (0, 0, 0)): - # anaddb expects cartesian coordinates for the qph2l list - if i > 0: - directions.extend( - structure.lattice.reciprocal_lattice_crystallographic.get_cartesian_coords( - kpath["kpoints"][qptbounds[i - 1]] - ) - ) - directions.append(0) - - if i < len(qptbounds) - 1: - directions.extend( - structure.lattice.reciprocal_lattice_crystallographic.get_cartesian_coords( - kpath["kpoints"][qptbounds[i + 1]] - ) - ) - directions.append(0) - - if directions: - directions = np.reshape(directions, (-1, 4)) # type: ignore - inp.set_vars(nph2l=len(directions), qph2l=directions) - - # Parameters for dielectric constant - if use_dieflag: - inp["dieflag"] = 1 - - return inp, labels_list - - @staticmethod - def get_pmg_bs( - phbands: PhononBands, labels_list: list - ) -> PhononBandStructureSymmLine: - """ - Generates a PhononBandStructureSymmLine starting from a abipy PhononBands object - - Args: - phbands (PhononBands): the phonon band structures - labels_list (list): list of labels used to generate the path - Returns: - An instance of PhononBandStructureSymmLine - """ - - structure = phbands.structure - - n_at = len(structure) - - qpts = np.array(phbands.qpoints.frac_coords) - ph_freqs = np.array(phbands.phfreqs) - displ = np.array(phbands.phdispl_cart) - - labels_dict = {} - - for i, (q, l) in enumerate(zip(qpts, labels_list)): - if l: - labels_dict[l] = q - # set LO-TO at gamma - if 
phbands.non_anal_ph and "Gamma" in l: - if i > 0 and not labels_list[i - 1]: - ph_freqs[i] = phbands._get_non_anal_freqs(qpts[i - 1]) - displ[i] = phbands._get_non_anal_phdispl(qpts[i - 1]) - if i < len(qpts) - 1 and not labels_list[i + 1]: - ph_freqs[i] = phbands._get_non_anal_freqs(qpts[i + 1]) - displ[i] = phbands._get_non_anal_phdispl(qpts[i + 1]) - - ph_freqs = np.transpose(ph_freqs) * eV_to_THz - displ = np.transpose( - np.reshape(displ, (len(qpts), 3 * n_at, n_at, 3)), (1, 0, 2, 3) - ) - - ph_bs_sl = PhononBandStructureSymmLine( - qpoints=qpts, # type: ignore[arg-type] - frequencies=ph_freqs, - lattice=structure.reciprocal_lattice, - has_nac=phbands.non_anal_ph is not None, - eigendisplacements=displ, - labels_dict=labels_dict, - structure=structure, - ) - - ph_bs_sl.band_reorder() - - return ph_bs_sl - - @staticmethod - def abinit_input_vars(item: dict) -> dict: - """ - Extracts the useful abinit input parameters from an item. - """ - - i = item["abinit_input"] - - data = {} - - def get_vars(label): - if label in i and i[label]: - return {k: v for (k, v) in i[label]["abi_args"]} - else: - return {} - - data["gs_input"] = get_vars("gs_input") - data["ddk_input"] = get_vars("ddk_input") - data["dde_input"] = get_vars("dde_input") - data["phonon_input"] = get_vars("phonon_input") - data["wfq_input"] = get_vars("wfq_input") - - data["ngqpt"] = i["ngqpt"] - data["ngkpt"] = i["ngkpt"] - data["shiftk"] = i["shiftk"] - data["ecut"] = i["ecut"] - data["occopt"] = i["occopt"] - data["tsmear"] = i.get("tsmear", 0) - - data["pseudopotentials"] = { - "name": i["pseudopotentials"]["pseudos_name"], - "md5": i["pseudopotentials"]["pseudos_md5"], - } - - return data - - def update_targets(self, items: list[dict]): - """ - Inserts the new task_types into the task_types collection - - Args: - items ([dict]): a list of phonon dictionaries to update - """ - items = list(filter(None, items)) - items_ph = [i["abiph"] for i in items] - items_ph_band = [i["phbs"] for i in items] - 
items_ph_dos = [i["phdos"] for i in items] - items_ddb = [i["ddb"] for i in items] - items_th_disp = [i["th_disp"] for i in items] - items_ph_web = [i["phws"] for i in items] - - if len(items) > 0: - self.logger.info("Updating {} phonon docs".format(len(items))) - self.phonon.update(docs=items_ph) - self.phonon_bs.update(docs=items_ph_band) - self.phonon_dos.update(docs=items_ph_dos) - self.ddb_files.update(docs=items_ddb) - self.th_disp.update(docs=items_th_disp) - self.phonon_website.update(docs=items_ph_web) - - else: - self.logger.info("No items to update") - - def ensure_indexes(self): - """ - Ensures indexes on the tasks and materials collections - """ - self.phonon_materials.ensure_index(self.phonon_materials.key, unique=True) - - self.phonon.ensure_index(self.phonon.key, unique=True) - self.phonon_bs.ensure_index(self.phonon.key, unique=True) - self.phonon_dos.ensure_index(self.phonon.key, unique=True) - self.ddb_files.ensure_index(self.phonon.key, unique=True) - self.th_disp.ensure_index(self.phonon.key, unique=True) - self.phonon_website.ensure_index(self.phonon.key, unique=True) - - -def get_warnings( - asr_break: float, cnsr_break: float, ph_bs: PhononBandStructureSymmLine -) -> list[PhononWarnings]: - """ - - Args: - asr_break (float): the largest breaking of the acoustic sum rule in cm^-1 - cnsr_break (float): the largest breaking of the charge neutrality sum rule - ph_bs (PhononBandStructureSymmLine): the phonon band structure - - Returns: - PhononWarnings: the model containing the data of the warnings. 
- """ - - warnings = [] - - if asr_break and asr_break > 30: - warnings.append(PhononWarnings.ASR) - if cnsr_break and cnsr_break > 0.2: - warnings.append(PhononWarnings.CNSR) - - # neglect small negative frequencies (0.03 THz ~ 1 cm^-1) - limit = -0.03 - - bands = np.array(ph_bs.bands) - neg_freq = bands < limit - - # there are negative frequencies anywhere in the BZ - if np.any(neg_freq): - warnings.append(PhononWarnings.NEG_FREQ) - - qpoints = np.array([q.frac_coords for q in ph_bs.qpoints]) - - qpt_has_neg_freq = np.any(neg_freq, axis=0) - - if np.max(np.linalg.norm(qpoints[qpt_has_neg_freq], axis=1)) < 0.05: - warnings.append(PhononWarnings.SMALL_Q_NEG_FREQ) - - return warnings - - -def get_thermodynamic_properties( - ph_dos: CompletePhononDos, -) -> tuple[ThermodynamicProperties, VibrationalEnergy]: - """ - Calculates the thermodynamic properties from a phonon DOS - - Args: - ph_dos (CompletePhononDos): The DOS used to calculate the properties. - - Returns: - ThermodynamicProperties and VibrationalEnergy: the models containing the calculated thermodynamic - properties and vibrational contribution to the total energy. 
- """ - - tstart, tstop, nt = 0, 800, 161 - temp = np.linspace(tstart, tstop, nt) - - cv = [] - entropy = [] - internal_energy = [] - helmholtz_free_energy = [] - - for t in temp: - cv.append(ph_dos.cv(t, ph_dos.structure)) - entropy.append(ph_dos.entropy(t, ph_dos.structure)) - internal_energy.append(ph_dos.internal_energy(t, ph_dos.structure)) - helmholtz_free_energy.append(ph_dos.helmholtz_free_energy(t, ph_dos.structure)) - - zpe = ph_dos.zero_point_energy(ph_dos.structure) - - temperatures = temp.tolist() - tp = ThermodynamicProperties(temperatures=temperatures, cv=cv, entropy=entropy) # type: ignore[call-arg] - - ve = VibrationalEnergy( - temperatures=temperatures, - internal_energy=internal_energy, - helmholtz_free_energy=helmholtz_free_energy, - zero_point_energy=zpe, - ) - - return tp, ve diff --git a/emmet-builders/emmet/builders/abinit/sound_velocity.py b/emmet-builders/emmet/builders/abinit/sound_velocity.py deleted file mode 100644 index df689d3753..0000000000 --- a/emmet-builders/emmet/builders/abinit/sound_velocity.py +++ /dev/null @@ -1,226 +0,0 @@ -import tempfile -import traceback -import warnings -from math import ceil -from typing import TYPE_CHECKING - -from abipy.dfpt.ddb import DdbFile -from abipy.dfpt.vsound import SoundVelocity as AbiSoundVelocity -from abipy.flowtk.tasks import TaskManager -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper - -from emmet.core.phonon import SoundVelocity -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - - -warnings.warn( - f"The current version of {__name__}.SoundVelocityBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class SoundVelocityBuilder(Builder): - def __init__( - self, - phonon_materials: Store, - ddb_source: Store, - sound_vel: Store, - query: dict | None = None, - manager: TaskManager | None = None, - **kwargs, - ): - """ - Creates a collection with the data of the sound velocities extracted from - the phonon calculations. - - Args: - phonon_materials (Store): source Store of phonon materials documents - containing abinit_input and abinit_output. - ddb_source (Store): source Store of ddb files. Matching the data in the materials Store. - sound_vel (Store): target Store of the sound velocity - query (dict): dictionary to limit materials to be analyzed - manager (TaskManager): an instance of the abipy TaskManager. If None it will be - generated from user configuration. - """ - - self.phonon_materials = phonon_materials - self.ddb_source = ddb_source - self.sound_vel = sound_vel - self.query = query or {} - - if manager is None: - self.manager = TaskManager.from_user_config() - else: - self.manager = manager - - super().__init__( - sources=[phonon_materials, ddb_source], targets=[sound_vel], **kwargs - ) - - def prechunk(self, number_splits: int): # pragma: no cover - """ - Gets all materials that need sound velocity - - Returns: - generator of materials to extract phonon sound velocity - """ - - # All relevant materials that have been updated since phonon props were last calculated - q = dict(self.query) - - mats = self.sound_vel.newer_in( - self.phonon_materials, exhaustive=True, criteria=q - ) - - N = ceil(len(mats) / number_splits) - - for mpid_chunk in grouper(mats, N): - yield {"query": {self.phonon_materials.key: {"$in": list(mpid_chunk)}}} - - def get_items(self) -> Iterator[dict]: - """ - Gets all materials that need sound velocity. 
- - Returns: - generator of materials to extract the sound velocity - """ - - self.logger.info("Sound Velocity Builder Started") - - self.logger.info("Setting indexes") - self.ensure_indexes() - - # All relevant materials that have been updated since sound velocities were last calculated - q = dict(self.query) - mats = self.sound_vel.newer_in( - self.phonon_materials, exhaustive=True, criteria=q - ) - self.logger.info( - "Found {} new materials for sound velocity data".format(len(mats)) - ) - - # list of properties queried from the results DB - # basic informations - projection = {"mp_id": 1} - # input data - projection["abinit_input"] = 1 - # file ids to be fetched - projection["abinit_output.ddb_id"] = 1 - - for m in mats: - item = self.phonon_materials.query_one( - properties=projection, criteria={self.phonon_materials.key: m} - ) - - # Read the DDB file and pass as an object. Do not write here since in case of parallel - # execution each worker will write its own file. - ddb_data = self.ddb_source.query_one( - criteria={"_id": item["abinit_output"]["ddb_id"]} - ) - - item["ddb_str"] = ddb_data["data"].decode("utf-8") - - yield item - - def process_item(self, item: dict) -> dict | None: - """ - Generates the sound velocity document from an item - - Args: - item (dict): a dict extracted from the phonon calculations results. 
- - Returns: - dict: a dict with phonon data - """ - self.logger.debug("Processing sound velocity item for {}".format(item["mp_id"])) - - try: - sound_vel_data = self.get_sound_vel(item) - - sv = SoundVelocity( # type: ignore[call-arg] - material_id=item["mp_id"], - structure=sound_vel_data["structure"], - directions=sound_vel_data["directions"], - labels=sound_vel_data["labels"], - sound_velocities=sound_vel_data["sound_velocities"], - mode_types=sound_vel_data["mode_types"], - ) - - self.logger.debug("Item generated for {}".format(item["mp_id"])) - - return jsanitize(sv.model_dump()) - except Exception: - self.logger.warning( - "Error generating the sound velocity for {}: {}".format( - item["mp_id"], traceback.format_exc() - ) - ) - return None - - @staticmethod - def get_sound_vel(item: dict) -> dict: - """ - Runs anaddb and return the extracted data for the speed of sound. - - Args: - item (dict): the item to process - Returns: - A dictionary with the sound velocity values - """ - with tempfile.NamedTemporaryFile( - mode="wt", suffix="_DDB", delete=True - ) as ddb_file: - ddb_file.write(item["ddb_str"]) - ngqpt = item["abinit_input"]["ngqpt"] - sv = AbiSoundVelocity.from_ddb( - ddb_file.name, - ngqpt=ngqpt, - num_points=20, - qpt_norm=0.1, - ignore_neg_freqs=True, - directions=None, - ) - - ddb = DdbFile.from_string(item["ddb_str"]) - sv_data = dict( - directions=sv.directions.tolist(), - sound_velocities=sv.sound_velocities.tolist(), - mode_types=sv.mode_types, - labels=sv.labels, - structure=ddb.structure, - ) - - return sv_data - - def update_targets(self, items: list[dict]): - """ - Inserts the new task_types into the task_types collection - - Args: - items ([dict]): a list of dictionaries with sound velocities - to update. 
- """ - self.logger.debug("Start update_targets") - items = list(filter(None, items)) - - if len(items) > 0: - self.logger.info("Updating {} sound velocity docs".format(len(items))) - self.sound_vel.update(docs=items) - else: - self.logger.info("No items to update") - - def ensure_indexes(self): - """ - Ensures indexes on the sound_vel collection. - """ - - # Search index for sound velocity - self.sound_vel.ensure_index(self.sound_vel.key, unique=True) diff --git a/emmet-builders/emmet/builders/base.py b/emmet-builders/emmet/builders/base.py new file mode 100644 index 0000000000..dc8ce97179 --- /dev/null +++ b/emmet-builders/emmet/builders/base.py @@ -0,0 +1,20 @@ +from pydantic import Field + +from emmet.core.base import EmmetBaseModel +from emmet.core.types.pymatgen_types.structure_adapter import StructureType +from emmet.core.types.typing import IdentifierType + + +class BaseBuilderInput(EmmetBaseModel): + """ + Document model with the minimum inputs required + to run builders that only require a Pymatgen structure + object for property analysis. + + A material_id and builder_meta information may be optionally + included. 
+ """ + + deprecated: bool = Field(False) + material_id: IdentifierType | None = Field(None) + structure: StructureType diff --git a/emmet-builders/emmet/builders/feff/__init__.py b/emmet-builders/emmet/builders/feff/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/emmet-builders/emmet/builders/feff/xas.py b/emmet-builders/emmet/builders/feff/xas.py deleted file mode 100644 index c76c2c09ec..0000000000 --- a/emmet-builders/emmet/builders/feff/xas.py +++ /dev/null @@ -1,77 +0,0 @@ -import traceback -import warnings -from datetime import datetime -from itertools import chain - -from maggma.builders import GroupBuilder -from maggma.core import Store - -from emmet.core.feff.task import TaskDocument as FEFFTaskDocument -from emmet.core.utils import jsanitize -from emmet.core.xas import XASDoc - -warnings.warn( - f"The current version of {__name__}.XASBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class XASBuilder(GroupBuilder): - """ - Generates XAS Docs from FEFF tasks - - # TODO: Generate MPID from materials collection rather than from task metadata - """ - - def __init__(self, tasks: Store, xas: Store, num_samples: int = 200, **kwargs): - - self.tasks = tasks - self.xas = xas - self.num_samples = num_samples - self.kwargs = kwargs - - super().__init__(source=tasks, target=xas, grouping_keys=["mp_id"]) - self._target_keys_field = "xas_ids" - - def process_item(self, spectra: list[dict]) -> dict: - # TODO: Change this to do structure matching against materials collection - mpid = spectra[0]["mp_id"] - - self.logger.debug(f"Processing: {mpid}") - - tasks = [FEFFTaskDocument(**task) for task in spectra] - - try: - docs = XASDoc.from_task_docs(tasks, material_id=mpid) - processed = [d.model_dump() for d in docs] - - for d in processed: - d.update({"state": "successful"}) - except Exception as e: - self.logger.error(traceback.format_exc()) - processed = [ - { - "error": str(e), - "state": "failed", - "task_ids": list(d.task_id for d in tasks), - } - ] - - update_doc = { - "_bt": datetime.utcnow(), - } - for d in processed: - d.update({k: v for k, v in update_doc.items() if k not in d}) - - return jsanitize(processed, allow_bson=True) - - def update_targets(self, items): - """ - Group buidler isn't designed for many-to-many so we unwrap that here - """ - - items = list(filter(None.__ne__, chain.from_iterable(items))) - super().update_targets(items) diff --git a/emmet-builders/emmet/builders/materials/absorption_spectrum.py b/emmet-builders/emmet/builders/materials/absorption_spectrum.py index e1c24b879f..b2517c20ce 100644 --- a/emmet-builders/emmet/builders/materials/absorption_spectrum.py +++ b/emmet-builders/emmet/builders/materials/absorption_spectrum.py @@ -1,224 +1,59 @@ -from __future__ import annotations - -import warnings -from math import ceil 
-from typing import TYPE_CHECKING - -import numpy as np -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure - +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map from emmet.core.absorption import AbsorptionDoc -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - - -warnings.warn( - f"The current version of {__name__}.AbsorptionBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class AbsorptionBuilder(Builder): - def __init__( - self, - materials: Store, - tasks: Store, - absorption: Store, - query: dict | None = None, - **kwargs, - ): - self.materials = materials - self.tasks = tasks - self.absorption = absorption - self.query = query or {} - self.kwargs = kwargs - - self.materials.key = "material_id" - self.tasks.key = "task_id" - self.absorption.key = "material_id" - - super().__init__(sources=[materials, tasks], targets=[absorption], **kwargs) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - keys = self.absorption.newer_in(self.materials, criteria=q, exhaustive=True) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def get_items(self) -> Iterator[list[dict]]: - """ - Gets all items to process - - Returns: - generator or list relevant tasks and materials to process - """ - - self.logger.info("Absorption Builder Started") - - q = dict(self.query) - - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - ab_ids = self.absorption.distinct(self.absorption.key) - - 
mats_set = set( - self.absorption.newer_in(target=self.materials, criteria=q, exhaustive=True) - ) | (set(mat_ids) - set(ab_ids)) - - mats = [mat for mat in mats_set] - - self.logger.info( - "Processing {} materials for absorption data".format(len(mats)) - ) - - self.total = len(mats) - - for mat in mats: - doc = self._get_processed_doc(mat) - - if doc is not None: - yield doc - else: - pass - - def process_item(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item[self.materials.key] - origin_entry = {"name": "absorption", "task_id": item["task_id"]} - - doc = AbsorptionDoc.from_structure( - structure=structure, - material_id=mpid, - task_id=item["task_id"], - deprecated=False, - energies=item["energies"], - real_d=item["real_dielectric"], - imag_d=item["imag_dielectric"], - absorption_co=item["optical_absorption_coeff"], - bandgap=item["bandgap"], - nkpoints=item["nkpoints"], - last_updated=item["updated_on"], - origins=[origin_entry], - ) - - return jsanitize(doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new absorption docs into the absorption collection - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} absorption docs to update") - self.absorption.update(docs) - else: - self.logger.info("No items to update") - - def _get_processed_doc(self, mat): - mat_doc = self.materials.query_one( - {self.materials.key: mat}, - [ - self.materials.key, - "structure", - "task_types", - "run_types", +from emmet.core.material import PropertyOrigin +from emmet.core.types.typing import DateTimeType + + +class AbsorptionBuilderInput(BaseBuilderInput): + energies: list[float] + real_d: list[float] + imag_d: list[float] + absorption_co: list[float] + bandgap: float | None + nkpoints: int | None + last_updated: DateTimeType + origins: list[PropertyOrigin] + + +def build_absorption_docs( + input_documents: list[AbsorptionBuilderInput], **kwargs +) -> 
list[AbsorptionDoc]: +    """ +    Generate absorption documents from input structures. + +    Transforms a list of AbsorptionBuilderInput documents containing +    Pymatgen structures into corresponding AbsorptionDoc instances by +    generating an absorption spectrum based on frequency dependent +    dielectric function outputs. + +    Caller is responsible for creating AbsorptionBuilderInput instances +    within their data pipeline context. + +    Args: +        input_documents: List of AbsorptionBuilderInput documents to process. + +    Returns: +        list[AbsorptionDoc] +    """ +    return list( +        filter_map( +            AbsorptionDoc.from_structure, +            input_documents, +            work_keys=[ +                "energies", +                "real_d", +                "imag_d", +                "absorption_co", +                "bandgap", +                "nkpoints", "last_updated", +                "origins", +                # PropertyDoc.from_structure(...) kwargs +                "deprecated", +                "material_id", +                "structure", ], +            **kwargs ) - -        task_types = mat_doc["task_types"].items() - -        potential_task_ids = [] - -        for task_id, task_type in task_types: -            if task_type == "Optic": -                potential_task_ids.append(task_id) - -        final_docs = [] - -        for task_id in potential_task_ids: -            task_query = self.tasks.query_one( -                properties=[ -                    "orig_inputs.kpoints", -                    "orig_inputs.structure", -                    "input.parameters", -                    "input.structure", -                    "output.dielectric.energy", -                    "output.dielectric.real", -                    "output.dielectric.imag", -                    "calcs_reversed", -                    "output.bandgap", -                ], -                criteria={self.tasks.key: task_id}, -            ) - -            if (cr := task_query.get("calcs_reversed", [])) and ( -                oac := cr[0]["output"]["optical_absorption_coeff"] -            ): -                try: -                    structure = task_query["input"]["structure"] -                except KeyError: -                    structure = task_query["orig_inputs"]["structure"] - -                if ( -                    task_query["orig_inputs"]["kpoints"]["generation_style"] -                    == "Monkhorst" -                    or task_query["orig_inputs"]["kpoints"]["generation_style"] -                    == "Gamma" -                ): -                    nkpoints = np.prod( -                        task_query["orig_inputs"]["kpoints"]["kpoints"][0], axis=0 -                    ) - -                else: -                    nkpoints = task_query["orig_inputs"]["kpoints"]["nkpoints"] - -                lu_dt = 
mat_doc["last_updated"] - - final_docs.append( - { - "task_id": task_id, - "nkpoints": int(nkpoints), - "energies": cr[0]["output"]["frequency_dependent_dielectric"][ - "energy" - ], - "real_dielectric": cr[0]["output"][ - "frequency_dependent_dielectric" - ]["real"], - "imag_dielectric": cr[0]["output"][ - "frequency_dependent_dielectric" - ]["imaginary"], - "optical_absorption_coeff": oac, - "bandgap": task_query["output"]["bandgap"], - "structure": structure, - "updated_on": lu_dt, - self.materials.key: mat_doc[self.materials.key], - } - ) - - if len(final_docs) > 0: - sorted_final_docs = sorted( - final_docs, - key=lambda entry: ( - entry["nkpoints"], - entry["updated_on"], - ), - reverse=True, - ) - return sorted_final_docs[0] - else: - return None + ) diff --git a/emmet-builders/emmet/builders/materials/alloys.py b/emmet-builders/emmet/builders/materials/alloys.py deleted file mode 100644 index f13ab7e108..0000000000 --- a/emmet-builders/emmet/builders/materials/alloys.py +++ /dev/null @@ -1,394 +0,0 @@ -import warnings -from itertools import chain, combinations - -from maggma.builders import Builder -from matminer.datasets import load_dataset -from pymatgen.analysis.alloys.core import ( - KNOWN_ANON_FORMULAS, - AlloyMember, - AlloyPair, - AlloySystem, - InvalidAlloy, -) -from pymatgen.core.structure import Structure -from tqdm import tqdm - -from emmet.core.types.enums import ThermoType - -# rough sort of ANON_FORMULAS by "complexity" -ANON_FORMULAS = sorted(KNOWN_ANON_FORMULAS, key=lambda af: len(af)) - -# Combinatorially, cannot StructureMatch every single possible pair of materials -# Use a loose spacegroup for a pre-screen (in addition to standard spacegroup) -LOOSE_SPACEGROUP_SYMPREC = 0.5 - -# A source of effective masses, should be replaced with MP-provided effective masses. 
-BOLTZTRAP_DF = load_dataset("boltztrap_mp") - -warnings.warn( - f"The current versions AlloyPairBuilder, AlloyPairMemberBuilder, and AlloySystemBuilder in {__name__} " - "will be deprecated in version 0.87.0. To continue using legacy builders please install " - "emmet-builders-legacy from git. A PyPI release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class AlloyPairBuilder(Builder): - """ - This builder iterates over anonymous_formula and builds AlloyPair. - It does not look for members of an AlloyPair. - """ - - def __init__( - self, - materials, - thermo, - electronic_structure, - provenance, - oxi_states, - alloy_pairs, - thermo_type: ThermoType | str = ThermoType.GGA_GGA_U_R2SCAN, - ): - self.materials = materials - self.thermo = thermo - self.electronic_structure = electronic_structure - self.provenance = provenance - self.oxi_states = oxi_states - self.alloy_pairs = alloy_pairs - - t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value - valid_types = {*map(str, ThermoType.__members__.values())} - if invalid_types := {t_type} - valid_types: - raise ValueError( - f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}" - ) - - self.thermo_type = t_type - - super().__init__( - sources=[materials, thermo, electronic_structure, provenance, oxi_states], - targets=[alloy_pairs], - chunk_size=8, - ) - - def ensure_indexes(self): - self.alloy_pairs.ensure_index("pair_id") - self.alloy_pairs.ensure_index("_search.id") - self.alloy_pairs.ensure_index("_search.formula") - self.alloy_pairs.ensure_index("_search.member_ids") - self.alloy_pairs.ensure_index("alloy_pair.chemsys") - - def get_items(self): - self.ensure_indexes() - - for idx, af in enumerate(ANON_FORMULAS): - # if af != "AB": - # continue - - thermo_docs = self.thermo.query( - criteria={ - "formula_anonymous": af, - "deprecated": False, - "thermo_type": self.thermo_type, - }, - properties=[ - "material_id", - 
"energy_above_hull", - "formation_energy_per_atom", - ], - ) - - thermo_docs = {d["material_id"]: d for d in thermo_docs} - - mpids = list(thermo_docs.keys()) - - docs = self.materials.query( - criteria={ - "material_id": {"$in": mpids}, - "deprecated": False, - }, # , "material_id": {"$in": ["mp-804", "mp-661"]}}, - properties=["structure", "material_id", "symmetry.number"], - ) - docs = {d["material_id"]: d for d in docs} - - electronic_structure_docs = self.electronic_structure.query( - {"material_id": {"$in": mpids}}, - properties=["material_id", "band_gap", "is_gap_direct"], - ) - electronic_structure_docs = { - d["material_id"]: d for d in electronic_structure_docs - } - - provenance_docs = self.provenance.query( - {"material_id": {"$in": mpids}}, - properties=["material_id", "theoretical", "database_IDs"], - ) - provenance_docs = {d["material_id"]: d for d in provenance_docs} - - oxi_states_docs = self.oxi_states.query( - {"material_id": {"$in": mpids}, "state": "successful"}, - properties=["material_id", "structure"], - ) - oxi_states_docs = {d["material_id"]: d for d in oxi_states_docs} - - for material_id in mpids: - d = docs[material_id] - - d["structure"] = Structure.from_dict(d["structure"]) - - if material_id in oxi_states_docs: - d["structure_oxi"] = Structure.from_dict( - oxi_states_docs[material_id]["structure"] - ) - else: - d["structure_oxi"] = d["structure"] - - # calculate loose space group - d["spacegroup_loose"] = d["structure"].get_space_group_info( - LOOSE_SPACEGROUP_SYMPREC - )[1] - - d["properties"] = {} - # patch in BoltzTraP data if present - row = BOLTZTRAP_DF.loc[BOLTZTRAP_DF["mpid"] == material_id] - if len(row) == 1: - d["properties"]["m_n"] = float(row.m_n) - d["properties"]["m_p"] = float(row.m_p) - - if material_id in electronic_structure_docs: - for key in ("band_gap", "is_gap_direct"): - d["properties"][key] = electronic_structure_docs[material_id][ - key - ] - - for key in ("energy_above_hull", "formation_energy_per_atom"): - 
d["properties"][key] = thermo_docs[material_id][key] - - if material_id in provenance_docs: - for key in ("theoretical",): - d["properties"][key] = provenance_docs[material_id][key] - - print( - f"Starting {af} with {len(docs)} materials, anonymous formula {idx} of {len(ANON_FORMULAS)}" - ) - - yield docs - - def process_item(self, item): - pairs = [] - for mpids in tqdm(list(combinations(item.keys(), 2))): - if ( - item[mpids[0]]["symmetry"]["number"] - == item[mpids[1]]["symmetry"]["number"] - ) or ( - item[mpids[0]]["spacegroup_loose"] == item[mpids[1]]["spacegroup_loose"] - ): - # optionally, could restrict based on band gap too (e.g. at least one end-point semiconducting) - # if (item[mpids[0]]["band_gap"] > 0) or (item[mpids[1]]["band_gap"] > 0): - try: - pair = AlloyPair.from_structures( - structures=[ - item[mpids[0]]["structure"], - item[mpids[1]]["structure"], - ], - structures_with_oxidation_states=[ - item[mpids[0]]["structure_oxi"], - item[mpids[1]]["structure_oxi"], - ], - ids=[mpids[0], mpids[1]], - properties=[ - item[mpids[0]]["properties"], - item[mpids[1]]["properties"], - ], - ) - pairs.append( - { - "alloy_pair": pair.as_dict(), - "_search": pair.search_dict(), - "pair_id": pair.pair_id, - } - ) - except InvalidAlloy: - pass - except Exception as exc: - print(exc) - - if pairs: - print(f"Found {len(pairs)} alloy(s)") - - return pairs - - def update_targets(self, items): - docs = list(chain.from_iterable(items)) - if docs: - self.alloy_pairs.update(docs) - - -class AlloyPairMemberBuilder(Builder): - """ - This builder iterates over available AlloyPairs by chemical system - and searches for possible members of those AlloyPairs. 
- """ - - def __init__(self, alloy_pairs, materials, snls, alloy_pair_members): - self.alloy_pairs = alloy_pairs - self.materials = materials - self.snls = snls - self.alloy_pair_members = alloy_pair_members - - super().__init__( - sources=[alloy_pairs, materials, snls], targets=[alloy_pair_members] - ) - - def ensure_indexes(self): - self.alloy_pairs.ensure_index("pair_id") - self.alloy_pairs.ensure_index("_search.id") - self.alloy_pairs.ensure_index("_search.formula") - self.alloy_pairs.ensure_index("_search.member_ids") - self.alloy_pairs.ensure_index("alloy_pair.chemsys") - self.alloy_pairs.ensure_index("alloy_pair.anonymous_formula") - - def get_items(self): - all_alloy_chemsys = set(self.alloy_pairs.distinct("alloy_pair.chemsys")) - all_known_chemsys = set(self.materials.distinct("chemsys")) | set( - self.snls.distinct("chemsys") - ) - possible_chemsys = all_known_chemsys.intersection(all_alloy_chemsys) - - print( - f"There are {len(all_alloy_chemsys)} alloy chemical systems of which " - f"{len(possible_chemsys)} may have members." 
- ) - - for idx, chemsys in enumerate(possible_chemsys): - pairs = self.alloy_pairs.query(criteria={"alloy_pair.chemsys": chemsys}) - pairs = [AlloyPair.from_dict(d["alloy_pair"]) for d in pairs] - - mp_docs = self.materials.query( - criteria={"chemsys": chemsys, "deprecated": False}, - properties=["structure", "material_id"], - ) - mp_structures = { - d["material_id"]: Structure.from_dict(d["structure"]) for d in mp_docs - } - - snl_docs = self.snls.query({"chemsys": chemsys}) - snl_structures = {d["snl_id"]: Structure.from_dict(d) for d in snl_docs} - - structures = mp_structures - structures.update(snl_structures) - - if structures: - yield (pairs, structures) - - def process_item(self, item: tuple[list[AlloyPair], dict[str, Structure]]): - pairs, structures = item - - all_pair_members = [] - for pair in pairs: - pair_members = {"pair_id": pair.pair_id, "members": []} - for db_id, structure in structures.items(): - try: - if pair.is_member(structure): - db, _ = db_id.split("-") - member = AlloyMember( - id_=db_id, - db=db, - composition=structure.composition, - is_ordered=structure.is_ordered, - x=pair.get_x(structure.composition), - ) - pair_members["members"].append(member.as_dict()) # type: ignore[attr-defined] - except Exception as exc: - print(f"Exception for {db_id}: {exc}") - if pair_members["members"]: - all_pair_members.append(pair_members) - - return all_pair_members - - def update_targets(self, items): - docs = list(chain.from_iterable(items)) - if docs: - self.alloy_pair_members.update(docs) - - -class AlloySystemBuilder(Builder): - """ - This builder stitches together the results of - AlloyPairBuilder and AlloyPairMemberBuilder. The output - of this collection is the one served by the AlloyPair API. - It also builds AlloySystem. 
- """ - - def __init__( - self, alloy_pairs, alloy_pair_members, alloy_pairs_merged, alloy_systems - ): - self.alloy_pairs = alloy_pairs - self.alloy_pair_members = alloy_pair_members - self.alloy_pairs_merged = alloy_pairs_merged - self.alloy_systems = alloy_systems - - super().__init__( - sources=[alloy_pairs, alloy_pair_members], - targets=[alloy_pairs_merged, alloy_systems], - chunk_size=8, - ) - - def get_items(self): - for idx, af in enumerate(ANON_FORMULAS): - # comment out to only calculate a single anonymous formula for debugging - # if af != "AB": - # continue - - docs = list(self.alloy_pairs.query({"alloy_pair.anonymous_formula": af})) - pair_ids = [d["pair_id"] for d in docs] - members = { - d["pair_id"]: d - for d in self.alloy_pair_members.query({"pair_id": {"$in": pair_ids}}) - } - - if docs: - yield docs, members - - def process_item(self, item): - pair_docs, members = item - - for doc in pair_docs: - if doc["pair_id"] in members: - doc["alloy_pair"]["members"] = members[doc["pair_id"]]["members"] - doc["_search"]["member_ids"] = [ - m["id_"] for m in members[doc["pair_id"]]["members"] - ] - else: - doc["alloy_pair"]["members"] = [] - doc["_search"]["member_ids"] = [] - - pairs = [AlloyPair.from_dict(d["alloy_pair"]) for d in pair_docs] - systems = AlloySystem.systems_from_pairs(pairs) - - system_docs = [ - { - "alloy_system": system.as_dict(), - "alloy_id": system.alloy_id, - "_search": {"member_ids": [m.id_ for m in system.members]}, - } - for system in systems - ] - - for system_doc in system_docs: - # Too big to store, will need to reconstruct separately from pair_ids - system_doc["alloy_system"]["alloy_pairs"] = None - - return pair_docs, system_docs - - def update_targets(self, items): - pair_docs, system_docs = [p for p, s in items], [s for p, s in items] - - pair_docs = list(chain.from_iterable(pair_docs)) - if pair_docs: - self.alloy_pairs_merged._collection.insert_many(pair_docs) - - system_docs = list(chain.from_iterable(system_docs)) - 
if system_docs: - self.alloy_systems._collection.insert_many(system_docs) diff --git a/emmet-builders/emmet/builders/materials/basic_descriptors.py b/emmet-builders/emmet/builders/materials/basic_descriptors.py deleted file mode 100644 index 32e0f3749d..0000000000 --- a/emmet-builders/emmet/builders/materials/basic_descriptors.py +++ /dev/null @@ -1,175 +0,0 @@ -import warnings - -import numpy as np -from maggma.builders import MapBuilder -from matminer.featurizers.composition import ElementProperty -from matminer.featurizers.site import CoordinationNumber, CrystalNNFingerprint -from pymatgen.analysis import local_env -from pymatgen.core.structure import Structure - -from emmet.core.structure import StructureMetadata - -# TODO: -# 1) ADD DOCUMENT MODEL -# 2) Add checking OPs present in current implementation of site fingerprints. -# 3) Complete documentation!!! - - -__author__ = "Nils E. R. Zimmermann " - -nn_target_classes = [ - "MinimumDistanceNN", - "VoronoiNN", - "CrystalNN", - "JmolNN", - "MinimumOKeeffeNN", - "MinimumVIRENN", - "BrunnerNN_reciprocal", - "BrunnerNN_relative", - "BrunnerNN_real", - "EconNN", -] - -warnings.warn( - f"The current version of {__name__}.BasicDescriptorsBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class BasicDescriptorsBuilder(MapBuilder): - def __init__(self, materials, descriptors, **kwargs): - """ - Calculates site-based descriptors (e.g., coordination numbers - with different near-neighbor finding approaches) for materials and - runs statistics analysis on selected descriptor types - (order parameter-based site fingerprints). The latter is - useful as a definition of a structure fingerprint - on the basis of local coordination information. - Furthermore, composition descriptors are calculated - (Magpie element property vector). 
- - Args: - materials (Store): Store of materials documents. - descriptors (Store): Store of composition, site, and - structure descriptor data such - as tetrahedral order parameter or - fraction of being 8-fold coordinated. - mat_query (dict): dictionary to limit materials to be analyzed. - """ - - self.materials = materials - self.descriptors = descriptors - - # Set up all targeted site descriptors. - self.sds = {} - for nn in nn_target_classes: - nn_ = getattr(local_env, nn) - k = "cn_{}".format(nn) - self.sds[k] = CoordinationNumber(nn_(), use_weights="none") - k = "cn_wt_{}".format(nn) - self.sds[k] = CoordinationNumber(nn_(), use_weights="sum") - self.all_output_pieces = {"site_descriptors": [k for k in self.sds.keys()]} - self.sds["csf"] = CrystalNNFingerprint.from_preset( - "ops", distance_cutoffs=None, x_diff_weight=None - ) - self.all_output_pieces["statistics"] = ["csf"] - - # Set up all targeted composition descriptors. - self.cds = {} - self.cds["magpie"] = ElementProperty.from_preset("magpie") - self.all_output_pieces["composition_descriptors"] = ["magpie"] - - self.all_output_pieces["meta"] = ["atomate"] - - super().__init__( - source=materials, target=descriptors, projection=["structure"], **kwargs - ) - - def unary_function(self, item): - """ - Calculates all basic descriptors for the structures - - - Args: - item (dict): a dict with a task_id and a structure - - Returns: - dict: a basic-descriptors dict - """ - self.logger.debug( - "Calculating basic descriptors for {}".format(item[self.materials.key]) - ) - - struct = Structure.from_dict(item["structure"]) - - descr_doc = {"structure": struct.copy()} - descr_doc["meta"] = StructureMetadata.from_structure(struct) - try: - comp_descr = [{"name": "magpie"}] - labels = self.cds["magpie"].feature_labels() - values = self.cds["magpie"].featurize(struct.composition) - for label, value in zip(labels, values): - comp_descr[0][label] = value - descr_doc["composition_descriptors"] = comp_descr - except 
Exception as e: - self.logger.error("Failed getting Magpie descriptors: " "{}".format(e)) - descr_doc["site_descriptors"] = self.get_site_descriptors_from_struct( - descr_doc["structure"] - ) - descr_doc["statistics"] = self.get_statistics(descr_doc["site_descriptors"]) - descr_doc[self.descriptors.key] = item[self.materials.key] - - return descr_doc - - def get_site_descriptors_from_struct(self, structure): - doc = {} - - # Compute descriptors. - for k, sd in self.sds.items(): - try: - d = [] - l = sd.feature_labels() - for i, s in enumerate(structure.sites): - d.append({"site": i}) - for j, desc in enumerate(sd.featurize(structure, i)): - d[i][l[j]] = desc - doc[k] = d - - except Exception as e: - self.logger.error( - "Failed calculating {} site-descriptors: " "{}".format(k, e) - ) - - return doc - - def get_statistics(self, site_descr, fps=("csf",)): - doc = {} - - # Compute site-descriptor statistics. - for fp in fps: - doc[fp] = {} - try: - n_site = len(site_descr[fp]) - tmp = {} - for isite in range(n_site): - for l, v in site_descr[fp][isite].items(): - if l not in list(tmp.keys()): - tmp[l] = [] - tmp[l].append(v) - d = [] - for k, l in tmp.items(): - dtmp = {"name": k} - dtmp["mean"] = np.mean(tmp[k]) - dtmp["std"] = np.std(tmp[k]) - d.append(dtmp) - doc[fp] = d - - except Exception as e: - self.logger.error( - "Failed calculating statistics of site " "descriptors: {}".format(e) - ) - - return doc diff --git a/emmet-builders/emmet/builders/materials/bonds.py b/emmet-builders/emmet/builders/materials/bonds.py index 3ea6fd0120..7ce63e214a 100644 --- a/emmet-builders/emmet/builders/materials/bonds.py +++ b/emmet-builders/emmet/builders/materials/bonds.py @@ -1,65 +1,46 @@ -import warnings - -from maggma.builders.map_builder import MapBuilder -from maggma.core import Store -from pymatgen.core import Structure from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map, 
try_call from emmet.core.bonds import BondingDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.BondingBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) -class BondingBuilder(MapBuilder): - def __init__( - self, - oxidation_states: Store, - bonding: Store, - **kwargs, - ): - """ - Creates Bonding documents from structures, ideally with - oxidation states already annotated but will also work from any - collection with structure and mp-id. +def build_bonding_docs( + input_documents: list[BaseBuilderInput], **kwargs +) -> list[BondingDoc]: + """ + Generate bonding documents from input structures. - Args: - oxidation_states: Store of oxidation states - bonding: Store to update with bonding documents - query : query on materials to limit search - """ - self.oxidation_states = oxidation_states - self.bonding = bonding - self.kwargs = kwargs + Transforms a list of BaseBuilderInput documents containing + Pymatgen structures into corresponding BondingDoc instances by + analyzing the bonding environment of each structure. - # Enforce that we key on material_id - self.oxidation_states.key = "material_id" - self.bonding.key = "material_id" - super().__init__( - source=oxidation_states, - target=bonding, - projection=["structure", "deprecated"], - **kwargs, - ) + Caller is responsible for creating BaseBuilderInput instances + within their data pipeline context. - def unary_function(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - deprecated = item["deprecated"] + Args: + input_documents: List of BaseBuilderInput documents to process. 
- # temporarily convert to conventional structure inside this builder, - # in future do structure setting operations in a separate builder - structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure() + Returns: + list[BondingDoc] + """ - bonding_doc = BondingDoc.from_structure( - structure=structure, material_id=mpid, deprecated=deprecated + def _build(deprecated: bool, material_id: str, structure, **kwargs) -> BondingDoc: + return BondingDoc.from_structure( + deprecated=deprecated, + material_id=material_id, + structure=try_call( + lambda s: SpacegroupAnalyzer(s).get_conventional_standard_structure(), + structure, + ), + **kwargs ) - doc = jsanitize(bonding_doc.model_dump(), allow_bson=True) - return doc + return list( + filter_map( + _build, + input_documents, + work_keys=["deprecated", "material_id", "structure"], + **kwargs + ) + ) diff --git a/emmet-builders/emmet/builders/materials/chemenv.py b/emmet-builders/emmet/builders/materials/chemenv.py index b40b0e59ee..a82a6f73d0 100644 --- a/emmet-builders/emmet/builders/materials/chemenv.py +++ b/emmet-builders/emmet/builders/materials/chemenv.py @@ -1,53 +1,33 @@ -import warnings - -from maggma.builders.map_builder import MapBuilder -from maggma.core import Store -from pymatgen.core.structure import Structure - +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map from emmet.core.chemenv import ChemEnvDoc -from emmet.core.utils import jsanitize -warnings.warn( - f"The current version of {__name__}.ChemEnvBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) +def build_chemenv_docs( + input_documents: list[BaseBuilderInput], **kwargs +) -> list[ChemEnvDoc]: + """ + Generate chemical environment documents from input structures. 
-class ChemEnvBuilder(MapBuilder): - def __init__( - self, - oxidation_states: Store, - chemenv: Store, - query: dict | None = None, - **kwargs, - ): - self.oxidation_states = oxidation_states - self.chemenv = chemenv - self.kwargs = kwargs + Transforms a list of BaseBuilderInput documents containing + Pymatgen structures into corresponding ChemEnvDoc instances by + analyzing the chemical environment of each structure. - self.chemenv.key = "material_id" - self.oxidation_states.key = "material_id" + Caller is responsible for creating BaseBuilderInput instances + within their data pipeline context. - super().__init__( - source=oxidation_states, - target=chemenv, - query=query, - projection=["material_id", "structure", "deprecated"], - **kwargs, - ) + Args: + input_documents: List of BaseBuilderInput documents to process. - def unary_function(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - deprecated = item["deprecated"] + Returns: + list[ChemEnvDoc] + """ - doc = ChemEnvDoc.from_structure( - structure=structure, - material_id=mpid, - deprecated=deprecated, + return list( + filter_map( + ChemEnvDoc.from_structure, + input_documents, + work_keys=["deprecated", "material_id", "structure", "builder_meta"], + **kwargs ) - - return jsanitize(doc.model_dump(), allow_bson=True) + ) diff --git a/emmet-builders/emmet/builders/materials/corrected_entries.py b/emmet-builders/emmet/builders/materials/corrected_entries.py index 2dfae05b92..f049e283de 100644 --- a/emmet-builders/emmet/builders/materials/corrected_entries.py +++ b/emmet-builders/emmet/builders/materials/corrected_entries.py @@ -1,336 +1,107 @@ -from __future__ import annotations - import copy import warnings -from collections import defaultdict -from datetime import datetime -from itertools import chain -from math import ceil -from typing import TYPE_CHECKING -from maggma.core import Builder, Store -from maggma.utils import grouper +from pydantic import BaseModel, 
Field from pymatgen.entries.compatibility import Compatibility -from pymatgen.entries.computed_entries import ComputedStructureEntry -from emmet.builders.utils import HiddenPrints, chemsys_permutations +from emmet.builders.utils import HiddenPrints from emmet.core.corrected_entries import CorrectedEntriesDoc -from emmet.core.types.enums import ThermoType -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - -warnings.warn( - f"The current version of {__name__}.CorrectedEntriesBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, +from emmet.core.thermo import ThermoType +from emmet.core.types.pymatgen_types.computed_entries_adapter import ( + ComputedStructureEntryType, ) -class CorrectedEntriesBuilder(Builder): - def __init__( - self, - materials: Store, - corrected_entries: Store, - oxidation_states: Store | None = None, - query: dict | None = None, - compatibility: list[Compatibility] | list[None] | None = [None], - chunk_size: int = 1000, - **kwargs, - ): - """ - Produces corrected thermo entry data from uncorrected materials entries. - This is meant to be an intermediate builder for the main thermo builder. - - Args: - materials (Store): Store of materials documents - corrected_entries (Store): Store to output corrected entry data - oxidation_states (Store): Store of oxidation state data to use in correction scheme application - query (dict): dictionary to limit materials to be analyzed - compatibility ([Compatibility]): Compatibility module - to ensure energies are compatible - chunk_size (int): Size of chemsys chunks to process at any one time. 
- """ - - self.materials = materials - self.query = query if query else {} - self.corrected_entries = corrected_entries - self.compatibility = compatibility - self.oxidation_states = oxidation_states - self.chunk_size = chunk_size - self._entries_cache: dict[str, list[dict]] = defaultdict(list) - - if self.corrected_entries.key != "chemsys": - warnings.warn( - "Key for the corrected_entries store is incorrect and has been changed " - f"from {self.corrected_entries.key} to thermo_id!" - ) - self.corrected_entries.key = "chemsys" - - if self.materials.key != "material_id": - warnings.warn( - f"Key for the materials store is incorrect and has been changed from {self.materials.key} to material_id!" # noqa: E501 - ) - self.materials.key = "material_id" - - sources = [materials] - - if self.oxidation_states is not None: - if self.oxidation_states.key != "material_id": - warnings.warn( - f"Key for the oxidation states store is incorrect and has been changed from {self.oxidation_states.key} to material_id!" 
# noqa:E501 - ) - self.oxidation_states.key = "material_id" - - sources.append(oxidation_states) # type: ignore - - targets = [corrected_entries] - - super().__init__( - sources=sources, targets=targets, chunk_size=chunk_size, **kwargs - ) - - def ensure_indexes(self): - """ - Ensures indicies on the tasks and materials collections - """ - - # Search index for materials - self.materials.ensure_index("material_id") - self.materials.ensure_index("chemsys") - self.materials.ensure_index("last_updated") - - # Search index for corrected_entries - self.corrected_entries.ensure_index("chemsys") - - def prechunk(self, number_splits: int) -> Iterable[dict]: # pragma: no cover - to_process_chemsys = self._get_chemsys_to_process() - - N = ceil(len(to_process_chemsys) / number_splits) - - for chemsys_chunk in grouper(to_process_chemsys, N): - yield {"query": {"chemsys": {"$in": list(chemsys_chunk)}}} - - def get_items(self) -> Iterator[list[dict]]: - """ - Gets whole chemical systems of entries to process - """ - - self.logger.info("Corrected Entries Builder Started") - - self.logger.info("Setting indexes") - self.ensure_indexes() - - to_process_chemsys = self._get_chemsys_to_process() - - self.logger.info( - f"Processing entries in {len(to_process_chemsys)} chemical systems" - ) - self.total = len(to_process_chemsys) - - # Yield the chemical systems in order of increasing size - for chemsys in sorted( - to_process_chemsys, key=lambda x: len(x.split("-")), reverse=False - ): - entries = self.get_entries(chemsys) - yield entries - - def process_item(self, item): - """ - Applies correction schemes to entries and constructs CorrectedEntriesDoc objects - """ - - if not item: - return None - - entries = [ComputedStructureEntry.from_dict(entry) for entry in item] - # determine chemsys - elements = sorted( - set([el.symbol for e in entries for el in e.composition.elements]) - ) - chemsys = "-".join(elements) - - self.logger.debug(f"Processing {len(entries)} entries for {chemsys}") - 
- all_entry_types = {str(e.data["run_type"]) for e in entries} +class CorrectedEntriesBuilderInput(BaseModel): + entries: list[ComputedStructureEntryType] = Field( + ..., + description=""" + List of computed structure entries to apply corrections to. + Entries MUST belong to a single chemical system (chemsys). + """, + ) + + +def build_corrected_entries_doc( + input: CorrectedEntriesBuilderInput, + compatibilities: list[Compatibility | None] = [None], +) -> CorrectedEntriesDoc: + """ + Process computed structure entries using corrections defined in pymatgen + compatibility classes. Ensures compatibility of energies for entries for + different thermodynamic hulls. + + Input entries must all belong to the same chemical system. Caller is + responsible for constructing CorrectedEntriesBuilderInput instances within + their data pipeline context. + + Args: + input: CorrectedEntriesBuilderInput with an aggregated list of computed + structure entries for a single chemical system. + compatibilities: List of pymatgen compatibility classes to apply to + input entries. + + Returns: + CorrectedEntriesDoc: if no Compatibility class(es) are provided, and all + entries have the same functional, no corrections will be applied and + entries will simply be passed through to CorrectedEntriesDoc constructor. 
+ """ + all_entry_types = {str(e.data["run_type"]) for e in input.entries} + + elements = sorted( + set([el.symbol for e in input.entries for el in e.composition.elements]) + ) + chemsys = "-".join(elements) + + corrected_entries = {} + + for compatibility in compatibilities: + if compatibility is not None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with HiddenPrints(): + if compatibility.name == "MP DFT mixing scheme": + thermo_type = ThermoType.GGA_GGA_U_R2SCAN + + if "R2SCAN" in all_entry_types: + only_scan_pd_entries = [ + e + for e in input.entries + if str(e.data["run_type"]) == "R2SCAN" + ] + corrected_entries["R2SCAN"] = only_scan_pd_entries - corrected_entries = {} - - for compatibility in self.compatibility: - if compatibility is not None: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with HiddenPrints(): - if compatibility.name == "MP DFT mixing scheme": - thermo_type = ThermoType.GGA_GGA_U_R2SCAN - - if "R2SCAN" in all_entry_types: - only_scan_pd_entries = [ - e - for e in entries - if str(e.data["run_type"]) == "R2SCAN" - ] - corrected_entries["R2SCAN"] = only_scan_pd_entries - - pd_entries = compatibility.process_entries( - copy.deepcopy(entries) - ) - - else: - corrected_entries["R2SCAN"] = None - pd_entries = None - - elif compatibility.name == "MP2020": - thermo_type = ThermoType.GGA_GGA_U - pd_entries = compatibility.process_entries( - copy.deepcopy(entries) - ) - else: - thermo_type = ThermoType.UNKNOWN pd_entries = compatibility.process_entries( - copy.deepcopy(entries) + copy.deepcopy(input.entries), + verbose=False, ) - corrected_entries[str(thermo_type)] = pd_entries - - else: - if len(all_entry_types) > 1: - raise ValueError( - "More than one functional type has been provided without a mixing scheme!" 
- ) - else: - thermo_type = all_entry_types.pop() - - corrected_entries[str(thermo_type)] = copy.deepcopy(entries) - - doc = CorrectedEntriesDoc(chemsys=chemsys, entries=corrected_entries) - - return jsanitize(doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new corrected entry docs into the corrected entries collection - """ - - docs = list(filter(None, items)) + else: + corrected_entries["R2SCAN"] = None + pd_entries = None + + elif compatibility.name == "MP2020": + thermo_type = ThermoType.GGA_GGA_U + pd_entries = compatibility.process_entries( + copy.deepcopy(input.entries), verbose=False + ) + else: + thermo_type = ThermoType.UNKNOWN + pd_entries = compatibility.process_entries( + copy.deepcopy(input.entries), verbose=False + ) + + corrected_entries[str(thermo_type)] = pd_entries - if len(docs) > 0: - self.logger.info(f"Updating {len(docs)} corrected entry documents") - self.corrected_entries.update(docs=docs, key=["chemsys"]) else: - self.logger.info("No corrected entry items to update") - - def get_entries(self, chemsys: str) -> list[dict]: - """ - Gets entries from the materials collection for the corresponding chemical systems - Args: - chemsys (str): a chemical system represented by string elements seperated by a dash (-) - Returns: - set (ComputedEntry): a set of entries for this system - """ - - self.logger.info(f"Getting entries for: {chemsys}") - # First check the cache - all_chemsys = chemsys_permutations(chemsys) - cached_chemsys = all_chemsys & set(self._entries_cache.keys()) - query_chemsys = all_chemsys - cached_chemsys - all_entries = list( - chain.from_iterable(self._entries_cache[c] for c in cached_chemsys) - ) - - self.logger.debug( - f"Getting {len(cached_chemsys)} sub-chemsys from cache for {chemsys}" - ) - self.logger.debug( - f"Getting {len(query_chemsys)} sub-chemsys from DB for {chemsys}" - ) - - # Second grab the materials docs - new_q = dict(self.query) - new_q["chemsys"] = {"$in": 
list(query_chemsys)} - new_q["deprecated"] = False - - materials_docs = list( - self.materials.query( - criteria=new_q, - properties=["material_id", "entries", "deprecated", "builder_meta"], - ) - ) - - # Get Oxidation state data for each material - oxi_states_data = {} - if self.oxidation_states: - material_ids = [t["material_id"] for t in materials_docs] - oxi_states_data = { - d["material_id"]: d.get("average_oxidation_states", {}) - for d in self.oxidation_states.query( - properties=[ - "material_id", - "average_oxidation_states", - ], - criteria={ - "material_id": {"$in": material_ids}, - "state": "successful", - }, + if len(all_entry_types) > 1: + # TODO: logging over raising + raise ValueError( + "More than one functional type has been provided without a mixing scheme!" ) - } - - self.logger.debug( - f"Got {len(materials_docs)} entries from DB for {len(query_chemsys)} sub-chemsys for {chemsys}" - ) - - # Convert entries into ComputedEntries and store - for doc in materials_docs: - for r_type, entry_dict in doc.get("entries", {}).items(): - if entry_dict: - entry_dict["data"]["oxidation_states"] = oxi_states_data.get( - entry_dict["data"]["material_id"], {} - ) - entry_dict["data"]["license"] = doc["builder_meta"].get("license") - entry_dict["data"]["run_type"] = r_type - elsyms = sorted(set([el for el in entry_dict["composition"]])) - self._entries_cache["-".join(elsyms)].append(entry_dict) - all_entries.append(entry_dict) - - self.logger.info(f"Total entries in {chemsys} : {len(all_entries)}") - - return all_entries - - def _get_chemsys_to_process(self): - # Use last-updated to find new chemsys - materials_chemsys_dates = {} - for d in self.materials.query( - {"deprecated": False, **self.query}, - properties=[self.corrected_entries.key, self.materials.last_updated_field], - ): - entry = materials_chemsys_dates.get(d[self.corrected_entries.key], None) - if entry is None or d[self.materials.last_updated_field] > entry: - 
materials_chemsys_dates[d[self.corrected_entries.key]] = d[ - self.materials.last_updated_field - ] - - corrected_entries_chemsys_dates = { - d[self.corrected_entries.key]: d[self.corrected_entries.last_updated_field] - for d in self.corrected_entries.query( - {}, - properties=[ - self.corrected_entries.key, - self.corrected_entries.last_updated_field, - ], - ) - } + else: + thermo_type = all_entry_types.pop() - to_process_chemsys = [ - chemsys - for chemsys in materials_chemsys_dates - if (chemsys not in corrected_entries_chemsys_dates) - or ( - materials_chemsys_dates[chemsys] - > datetime.fromisoformat(corrected_entries_chemsys_dates[chemsys]) - ) - ] + corrected_entries[str(thermo_type)] = copy.deepcopy(input.entries) - return to_process_chemsys + return CorrectedEntriesDoc(chemsys=chemsys, entries=corrected_entries) diff --git a/emmet-builders/emmet/builders/materials/dielectric.py b/emmet-builders/emmet/builders/materials/dielectric.py deleted file mode 100644 index f44fdf4015..0000000000 --- a/emmet-builders/emmet/builders/materials/dielectric.py +++ /dev/null @@ -1,222 +0,0 @@ -from __future__ import annotations - -import warnings -from math import ceil -from typing import TYPE_CHECKING - -import numpy as np -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure - -from emmet.core.polar import DielectricDoc -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - -warnings.warn( - f"The current version of {__name__}.DielectricBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class DielectricBuilder(Builder): - def __init__( - self, - materials: Store, - tasks: Store, - dielectric: Store, - query: dict | None = None, - **kwargs, - ): - self.materials = materials - self.tasks = tasks - self.dielectric = dielectric - self.query = query or {} - self.kwargs = kwargs - - self.materials.key = "material_id" - self.tasks.key = "task_id" - self.dielectric.key = "material_id" - - super().__init__(sources=[materials, tasks], targets=[dielectric], **kwargs) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - keys = self.dielectric.newer_in(self.materials, criteria=q, exhaustive=True) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def get_items(self): - """ - Gets all items to process - - Returns: - generator or list relevant tasks and materials to process - """ - - self.logger.info("Dielectric Builder Started") - - q = dict(self.query) - - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - di_ids = self.dielectric.distinct(self.dielectric.key) - - mats_set = set( - self.dielectric.newer_in(target=self.materials, criteria=q, exhaustive=True) - ) | (set(mat_ids) - set(di_ids)) - - mats = [mat for mat in mats_set] - - self.logger.info( - "Processing {} materials for dielectric data".format(len(mats)) - ) - - self.total = len(mats) - - for mat in mats: - doc = self._get_processed_doc(mat) - - if doc is not None: - yield doc - else: - pass - - def process_item(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - origin_entry = { - "name": "dielectric", - "task_id": item["task_id"], - "last_updated": item["task_updated"], - } - - doc = DielectricDoc.from_ionic_and_electronic( 
- structure=structure, - material_id=mpid, - origins=[origin_entry], - deprecated=False, - ionic=item["epsilon_ionic"], - electronic=item["epsilon_static"], - last_updated=item["updated_on"], - ) - - return jsanitize(doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new dielectric docs into the dielectric collection - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} dielectric docs to update") - self.dielectric.update(docs) - else: - self.logger.info("No items to update") - - def _get_processed_doc(self, mat): - mat_doc = self.materials.query_one( - {self.materials.key: mat}, - [ - self.materials.key, - "structure", - "task_types", - "run_types", - "deprecated_tasks", - "last_updated", - ], - ) - - task_types = mat_doc["task_types"].items() - - potential_task_ids = [] - - for task_id, task_type in task_types: - if task_type == "DFPT Dielectric": - if task_id not in mat_doc["deprecated_tasks"]: - potential_task_ids.append(task_id) - - final_docs = [] - - for task_id in potential_task_ids: - task_query = self.tasks.query_one( - properties=[ - "last_updated", - "input.is_hubbard", - "orig_inputs.kpoints", - "orig_inputs.structure", - "input.parameters", - "input.structure", - "calcs_reversed", - "output.bandgap", - ], - criteria={self.tasks.key: str(task_id)}, - ) - - if task_query["output"]["bandgap"] > 0: - try: - structure = task_query["input"]["structure"] - except KeyError: - structure = task_query["orig_inputs"]["structure"] - - is_hubbard = task_query["input"]["is_hubbard"] - - if ( - task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Monkhorst" - or task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Gamma" - ): - nkpoints = np.prod( - task_query["orig_inputs"]["kpoints"]["kpoints"][0], axis=0 - ) - - else: - nkpoints = task_query["orig_inputs"]["kpoints"]["nkpoints"] - - lu_dt = mat_doc["last_updated"] - task_updated = task_query["last_updated"] 
- - final_docs.append( - { - "task_id": task_id, - "is_hubbard": int(is_hubbard), - "nkpoints": int(nkpoints), - "epsilon_static": task_query["calcs_reversed"][0]["output"][ - "epsilon_static" - ], - "epsilon_ionic": task_query["calcs_reversed"][0]["output"][ - "epsilon_ionic" - ], - "structure": structure, - "updated_on": lu_dt, - "task_updated": task_updated, - self.materials.key: mat_doc[self.materials.key], - } - ) - - if len(final_docs) > 0: - sorted_final_docs = sorted( - final_docs, - key=lambda entry: ( - entry["is_hubbard"], - entry["nkpoints"], - entry["updated_on"], - ), - reverse=True, - ) - return sorted_final_docs[0] - else: - return None diff --git a/emmet-builders/emmet/builders/materials/elasticity.py b/emmet-builders/emmet/builders/materials/elasticity.py deleted file mode 100644 index e34eb203a4..0000000000 --- a/emmet-builders/emmet/builders/materials/elasticity.py +++ /dev/null @@ -1,507 +0,0 @@ -""" -Builder to generate elasticity docs. - -The build proceeds in the below steps: -1. Use materials builder to group tasks according the formula, space group, and - structure matching. -2. Filter opt and deform tasks by calc type. -3. Filter opt and deform tasks to match prescribed INCAR params. -4. Group opt tasks by optimized lattice, and, for each group, select the latest task - (the one with the newest completing time). This result in a {lat, opt_task} dict. -5. Group deform tasks by parent lattice (i.e. lattice before a deformation gradient is - applied). For each lattice group, then group the tasks by deformation gradient, - and select the latest task for each deformation gradient. This result in a {lat, - [deform_task]} dict, where [deform_task] are tasks with unique deformation gradients. -6. Associate opt and deform tasks by matching parent lattice. Then select the one with - the most deformation tasks as the final data for fitting the elastic tensor. -7. Fit the elastic tensor. 
-""" - -from __future__ import annotations - -import warnings -from datetime import datetime -from typing import TYPE_CHECKING - -import numpy as np -from maggma.core import Builder, Store -from pydash.objects import get -from pymatgen.analysis.elasticity.strain import Deformation -from pymatgen.analysis.elasticity.stress import Stress -from pymatgen.core import Structure -from pymatgen.core.tensors import TensorMapping - -from emmet.core.elasticity import ElasticityDoc -from emmet.core.mpid import AlphaID -from emmet.core.utils import jsanitize -from emmet.core.vasp.calc_types import CalcType - -if TYPE_CHECKING: - from collections.abc import Generator - from typing import Any - - from emmet.core.types.typing import IdentifierType - -warnings.warn( - f"The current version of {__name__}.ElasticityBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class ElasticityBuilder(Builder): - def __init__( - self, - tasks: Store, - materials: Store, - elasticity: Store, - query: dict | None = None, - fitting_method: str = "finite_difference", - **kwargs, - ): - """ - Creates an elastic collection for materials. 
- - Args: - tasks: Store of tasks - materials: Store of materials - elasticity: Store of elasticity - query: Mongo-like query to limit the tasks to be analyzed - fitting_method: method to fit the elastic tensor: {`finite_difference`, - `pseudoinverse`, `independent`} - """ - - self.tasks = tasks - self.materials = materials - self.elasticity = elasticity - self.query = query if query is not None else {} - self.fitting_method = fitting_method - self.kwargs = kwargs - - super().__init__(sources=[tasks, materials], targets=[elasticity], **kwargs) - - def ensure_index(self): - self.tasks.ensure_index("nsites") - self.tasks.ensure_index("formula_pretty") - self.tasks.ensure_index("last_updated") - - self.materials.ensure_index("material_id") - self.materials.ensure_index("last_updated") - - self.elasticity.ensure_index("material_id") - self.elasticity.ensure_index("last_updated") - - def get_items( - self, - ) -> Generator[tuple[str, dict[str, str], list[dict]], None, None]: - """ - Gets all items to process into elasticity docs. 
- - Returns: - material_id: material id for the tasks - calc_types: calculation types of the tasks - tasks: task docs belong to the same material - """ - - self.logger.info("Elastic Builder Started") - - self.ensure_index() - - cursor = self.materials.query( - criteria=self.query, properties=["material_id", "calc_types", "task_ids"] - ) - - # query for tasks - # query = self.query.copy() - tasks_query = {} - - for i, doc in enumerate(cursor): - material_id = doc["material_id"] - calc_types = {str(k): v for k, v in doc["calc_types"].items()} - - self.logger.debug(f"Querying tasks for material {material_id} (index {i}).") - - # update query with task_ids - try: - ids_list = [int(i) for i in doc["task_ids"]] - except ValueError: - ids_list = [i for i in doc["task_ids"]] - - tasks_query["task_id"] = {"$in": ids_list} - - projections = [ - "output", - "orig_inputs", - "completed_at", - "transformations", - "task_id", - "dir_name", - ] - - task_cursor = self.tasks.query(criteria=tasks_query, properties=projections) - tasks = list(task_cursor) - - yield material_id, calc_types, tasks - - def process_item( - self, item: tuple[IdentifierType, dict[str, str], list[dict]] - ) -> dict | None: - """ - Process all tasks belong to the same material into an elasticity doc. - - Args: - item: - material_id: material id for the tasks - calc_types: {task_id: task_type} calculation types of the tasks - tasks: task docs belong to the same material - - Returns: - Elasticity doc obtained from the list of tasks. `None` if failed to - obtain the elasticity doc from the tasks. - """ - - material_id, calc_types, tasks = item - - if len(tasks) != len(calc_types): - self.logger.error( - f"Number of tasks ({len(tasks)}) is not equal to number of calculation " - f"types ({len(calc_types)}) for material with material id " - f"{material_id}. Cannot proceed." 
- ) - return None - - # filter by calc type - opt_tasks = filter_opt_tasks(tasks, calc_types) - deform_tasks = filter_deform_tasks(tasks, calc_types) - if not opt_tasks or not deform_tasks: - return None - - # filter by incar - opt_tasks = filter_by_incar_settings(opt_tasks) - deform_tasks = filter_by_incar_settings(deform_tasks) - if not opt_tasks or not deform_tasks: - return None - - # select one task for each set of optimization tasks with the same lattice - opt_grouped_tmp = group_by_parent_lattice(opt_tasks, mode="opt") - opt_grouped = [ - (lattice, filter_opt_tasks_by_time(tasks, self.logger)) - for lattice, tasks in opt_grouped_tmp - ] - - # for deformed tasks with the same lattice, select one if there are multiple - # tasks with the same deformation - deform_grouped = group_by_parent_lattice(deform_tasks, mode="deform") - deform_grouped = [ - (lattice, filter_deform_tasks_by_time(tasks, logger=self.logger)) - for lattice, tasks in deform_grouped - ] - - # select opt and deform tasks for fitting - final_opt, final_deform = select_final_opt_deform_tasks( - opt_grouped, deform_grouped, self.logger - ) - if final_opt is None or final_deform is None: - return None - - # convert to elasticity doc - deforms = [] - stresses: list[Stress] = [] # TODO: mypy misfires on `Stress` - deform_task_ids = [] - deform_dir_names = [] - for doc in final_deform: - deforms.append( - Deformation(doc["transformations"]["history"][0]["deformation"]) - ) - # 0.1 to convert to GPa from kBar, and the minus sign to flip the stress - # direction from compressive as positive (in vasp) to tensile as positive - stresses.append(-0.1 * Stress(doc["output"]["stress"])) # type: ignore[arg-type] - deform_task_ids.append(doc["task_id"]) - deform_dir_names.append(doc["dir_name"]) - - elasticity_doc = ElasticityDoc.from_deformations_and_stresses( - structure=Structure.from_dict(final_opt["output"]["structure"]), - material_id=material_id, - deformations=deforms, - stresses=stresses, # type: 
ignore[arg-type] - deformation_task_ids=deform_task_ids, - deformation_dir_names=deform_dir_names, - equilibrium_stress=-0.1 * Stress(final_opt["output"]["stress"]), # type: ignore[arg-type] - optimization_task_id=final_opt["task_id"], - optimization_dir_name=final_opt["dir_name"], - fitting_method="finite_difference", - ) - elasticity_doc = jsanitize(elasticity_doc.model_dump(), allow_bson=True) - - return elasticity_doc - - def update_targets(self, items: list[dict]): - """ - Insert the new elasticity docs into the elasticity collection. - - Args: - items: elasticity docs - """ - self.logger.info(f"Updating {len(items)} elasticity documents") - - self.elasticity.update(items, key="material_id") - - -def filter_opt_tasks( - tasks: list[dict], - calc_types: dict[str, str], - target_calc_type: str | CalcType = CalcType.GGA_Structure_Optimization, -) -> list[dict]: - """ - Filter optimization tasks, by - - calculation type - """ - opt_tasks = [ - t for t in tasks if calc_types[str(AlphaID(t["task_id"]))] == target_calc_type - ] - - return opt_tasks - - -def filter_deform_tasks( - tasks: list[dict], - calc_types: dict[str, str], - target_calc_type: str | CalcType = CalcType.GGA_Deformation, -) -> list[dict]: - """ - Filter deformation tasks, by - - calculation type - - number of transformations - - transformation class - """ - deform_tasks = [] - for t in tasks: - if calc_types[str(t["task_id"])] == target_calc_type: - transforms = t.get("transformations", {}).get("history", []) - if ( - len(transforms) == 1 - and transforms[0]["@class"] == "DeformStructureTransformation" - ): - deform_tasks.append(t) - - return deform_tasks - - -def filter_by_incar_settings( - tasks: list[dict], incar_settings: dict[str, Any] | None = None -) -> list[dict]: - """ - Filter tasks by incar parameters. 
- """ - - if incar_settings is None: - incar_settings = { - "LREAL": False, - "ENCUT": 700, - "PREC": "Accurate", - "EDIFF": 1e-6, - } - - selected = [] - for t in tasks: - incar = t["orig_inputs"]["incar"] - ok = True - for k, v in incar_settings.items(): - if k not in incar: - ok = False - break - - if isinstance(incar[k], str): - if incar[k].lower() != str(v).lower(): - ok = False - break - - elif isinstance(incar[k], float): - if not np.allclose(incar[k], v, atol=1e-10): - ok = False - break - - else: - if incar[k] != v: - ok = False - break - - if ok: - selected.append(t) - - return selected - - -def filter_opt_tasks_by_time(tasks: list[dict], logger) -> dict: - """ - Filter a set of tasks to select the latest completed one. - - Args: - tasks: the set of tasks to filter - logger: - - Returns: - selected latest task - """ - return _filter_tasks_by_time(tasks, "optimization", logger) - - -def filter_deform_tasks_by_time( - tasks: list[dict], deform_comp_tol: float = 1e-5, logger=None -) -> list[dict]: - """ - For deformation tasks with the same deformation, select the latest completed one. - - Args: - tasks: the deformation tasks - deform_comp_tol: tolerance for comparing deformation equivalence - - Returns: - filtered deformation tasks - """ - - mapping = TensorMapping(tol=deform_comp_tol) - - # group tasks by deformation - for doc in tasks: - # assume only one deformation, should be checked in `filter_deform_tasks()` - deform = doc["transformations"]["history"][0]["deformation"] - - if deform in mapping: - mapping[deform].append(doc) - else: - mapping[deform] = [doc] - - # select the latest task for each deformation - selected = [] - for docs in mapping.values(): - t = _filter_tasks_by_time(docs, "deformation", logger) - selected.append(t) - - return selected - - -def _filter_tasks_by_time(tasks: list[dict], mode: str, logger) -> dict: - """ - Helper function to filter a set of tasks to select the latest completed one. 
- """ - if len(tasks) == 0: - raise RuntimeError(f"Cannot filter {mode} task from 0 input tasks") - elif len(tasks) == 1: - return tasks[0] - - completed = [(datetime.fromisoformat(t["completed_at"]), t) for t in tasks] - sorted_by_completed = sorted(completed, key=lambda pair: pair[0]) - latest_pair = sorted_by_completed[-1] - selected = latest_pair[1] - - task_ids = [t["task_id"] for t in tasks] - logger.info( - f"Found multiple {mode} tasks {task_ids}; selected the latest task " - f"{selected['task_id']} that is completed at {selected['completed_at']}." - ) - - return selected - - -def select_final_opt_deform_tasks( - opt_tasks: list[tuple[np.ndarray, dict]], - deform_tasks: list[tuple[np.ndarray, list[dict]]], - logger, - lattice_comp_tol: float = 1e-5, -) -> tuple[dict | None, list[dict] | None]: - """ - Select the final opt task and deform tasks for fitting. - - This is achieved by selecting the opt--deform pairs with the same lattice, - and also with the most deform tasks. - - Returns: - final_opt_task: selected opt task - final_deform_tasks: selected deform tasks - """ - - # group opt and deform tasks by lattice - mapping = TensorMapping(tol=lattice_comp_tol) - for lat, opt_t in opt_tasks: - mapping[lat] = {"opt_task": opt_t} - - for lat, dt in deform_tasks: - if lat in mapping: - mapping[lat]["deform_tasks"] = dt - else: - mapping[lat] = {"deform_tasks": dt} - - # select opt--deform paris with the most deform tasks - selected = None - num_deform_tasks = -1 - for lat, tasks in mapping.items(): - if "opt_task" in tasks and "deform_tasks" in tasks: - n = len(tasks["deform_tasks"]) - if n > num_deform_tasks: - selected = (tasks["opt_task"], tasks["deform_tasks"]) - num_deform_tasks = n - - if selected is None: - tasks = [pair[1] for pair in opt_tasks] - for pair in deform_tasks: - tasks.extend(pair[1]) - - ids = [t["task_id"] for t in tasks] - logger.warning( - f"Cannot find optimization and deformation tasks that match by lattice " - f"for tasks {ids}" - ) - 
- final_opt_task = None - final_deform_tasks = None - else: - final_opt_task, final_deform_tasks = selected - - return final_opt_task, final_deform_tasks - - -def group_by_parent_lattice( - tasks: list[dict], mode: str, lattice_comp_tol: float = 1e-5 -) -> list[tuple[np.ndarray, list[dict]]]: - """ - Groups a set of task docs by parent lattice equivalence. - - Args: - tasks: task docs - mode: determines which lattice to use. If `opt`, use the lattice of the - output structure, and this is intended for optimization tasks. If - `deform`, use the lattice of the output structure and transform it by the - deformation in transformation, and this is intended for deformation tasks. - lattice_comp_tol: tolerance for comparing lattice equivalence. - - Returns: - [(lattice, list[tasks])]: each tuple gives the common parent lattice of a - list of the structures before deformation (if any), and the list tasks - from which the structures are taken. - """ - docs_by_lattice: list[tuple[np.ndarray, list[dict]]] = [] - - for doc in tasks: - sim_lattice = get(doc, "output.structure.lattice.matrix") - - if mode == "deform": - deform = doc["transformations"]["history"][0]["deformation"] - parent_lattice = np.dot(sim_lattice, np.transpose(np.linalg.inv(deform))) - elif mode == "opt": - parent_lattice = np.array(sim_lattice) - else: - raise ValueError(f"Unsupported mode {mode}") - - match = False - for unique_lattice, lattice_docs in docs_by_lattice: - match = np.allclose(unique_lattice, parent_lattice, atol=lattice_comp_tol) - if match: - lattice_docs.append(doc) - break - if not match: - docs_by_lattice.append((parent_lattice, [doc])) - - return docs_by_lattice diff --git a/emmet-builders/emmet/builders/materials/electrodes.py b/emmet-builders/emmet/builders/materials/electrodes.py deleted file mode 100644 index e57f2b2133..0000000000 --- a/emmet-builders/emmet/builders/materials/electrodes.py +++ /dev/null @@ -1,642 +0,0 @@ -import operator -import warnings -from collections import 
defaultdict -from datetime import datetime -from functools import lru_cache -from itertools import chain -from math import ceil -from typing import TYPE_CHECKING - -from maggma.builders import Builder -from maggma.stores import MongoStore -from maggma.utils import grouper -from pymatgen.analysis.phase_diagram import Composition, PhaseDiagram -from pymatgen.entries.compatibility import MaterialsProject2020Compatibility -from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry - -from emmet.builders.settings import EmmetBuildSettings -from emmet.core.electrode import ConversionElectrodeDoc, InsertionElectrodeDoc -from emmet.core.structure_group import StructureGroupDoc, _get_id_lexi -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - from typing import Any - -warnings.warn( - "The current versions of StructureGroupBuilder, InsertionElectrodeBuilder, and " - f"ConversionElectrodeBuilder in {__name__} will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -def s_hash(el): - return el.data["comp_delith"] - - -REDOX_ELEMENTS = [ - "Ti", - "V", - "Cr", - "Mn", - "Fe", - "Co", - "Ni", - "Cu", - "Nb", - "Mo", - "Ag", - "Sn", - "Sb", - "W", - "Re", - "Bi", - "C", -] - -WORKING_IONS = ["Li", "Na", "K", "Mg", "Ca", "Zn", "Al"] - -MAT_PROPS = ["structure", "material_id", "formula_pretty", "entries"] - -sg_fields = ["number", "hall_number", "international", "hall", "choice"] - - -def generic_groupby(list_in, comp=operator.eq): - """ - Group a list of unsortable objects - Args: - list_in: A list of generic objects - comp: (Default value = operator.eq) The comparator - Returns: - [int] list of labels for the input list - """ - list_out = [None] * len(list_in) - label_num = 0 - for i1, ls1 in enumerate(list_out): - if ls1 is not None: - continue - list_out[i1] = label_num - for i2, ls2 in list(enumerate(list_out))[i1 + 1 :]: - if comp(list_in[i1], list_in[i2]): - if list_out[i2] is None: - list_out[i2] = list_out[i1] - else: - list_out[i1] = list_out[i2] - label_num -= 1 - label_num += 1 - return list_out - - -default_build_settings = EmmetBuildSettings() - - -class StructureGroupBuilder(Builder): - def __init__( - self, - materials: MongoStore, - sgroups: MongoStore, - working_ion: str, - query: dict | None = None, - ltol: float = default_build_settings.LTOL, - stol: float = default_build_settings.STOL, - angle_tol: float = default_build_settings.ANGLE_TOL, - check_newer: bool = True, - chunk_size: int = 1000, - **kwargs, - ): - """ - Aggregate materials entries into sgroups that are topotactically similar to each other. 
- This is an incremental builder that makes ensures that each materials id belongs to one StructureGroupDoc - document - Args: - materials (Store): Store of materials documents that contains the structures - sgroups (Store): Store of grouped material ids - query (dict): dictionary to limit materials to be analyzed --- - only applied to the materials when we need to group structures - the phase diagram is still constructed with the entire set - chunk_size (int): Size of chemsys chunks to process at any one time. - """ - self.materials = materials - self.sgroups = sgroups - self.working_ion = working_ion - self.query = query if query else {} - self.ltol = ltol - self.stol = stol - self.angle_tol = angle_tol - self.check_newer = check_newer - self.chunk_size = chunk_size - - self.query["deprecated"] = ( - False # Ensure only non-deprecated materials are chosen - ) - - super().__init__( - sources=[materials], targets=[sgroups], chunk_size=chunk_size, **kwargs - ) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - all_chemsys = self.materials.distinct("chemsys", criteria=q) - - new_chemsys_list = [] - - for chemsys in all_chemsys: - elements = [ - element for element in chemsys.split("-") if element != self.working_ion - ] - new_chemsys = "-".join(sorted(elements)) - new_chemsys_list.append(new_chemsys) - - N = ceil(len(new_chemsys_list) / number_splits) - - for split in grouper(new_chemsys_list, N): - new_split_add = [] - for chemsys in split: - elements = [element for element in chemsys.split("-")] + [ - self.working_ion - ] - new_chemsys = "-".join(sorted(elements)) - new_split_add.append(new_chemsys) - - yield {"query": {"chemsys": {"$in": new_split_add + split}}} - - def get_items(self): - """ - Summary of the steps: - - query the materials database for different chemical systems that satisfies the base query - "contains redox element and 
working ion" - - Get the full chemsys list of interest - - The main loop is over all these chemsys. within the main loop: - - get newest timestamp for the material documents (max_mat_time) - - get the oldest timestamp for the target documents (min_target_time) - - if min_target_time is < max_mat_time then nuke all the target documents - """ - # All potentially interesting chemsys must contain the working ion - base_query = { - "$and": [ - self.query.copy(), - {"elements": {"$in": REDOX_ELEMENTS}}, - {"elements": {"$in": [self.working_ion]}}, - ] - } - self.logger.debug(f"Initial Chemsys QUERY: {base_query}") - - # get a chemsys that only contains the working ion since the working ion - # must be present for there to be voltage steps - all_chemsys = self.materials.distinct("chemsys", criteria=base_query) - # Contains the working ion but not ONLY the working ion - all_chemsys = [ - *filter( - lambda x: self.working_ion in x and len(x) > 1, - [chemsys_.split("-") for chemsys_ in all_chemsys], - ) - ] - - self.logger.debug( - f"Performing initial checks on {len(all_chemsys)} chemical systems containing redox elements w/ or w/o wion" - ) - self.total = len(all_chemsys) - - for chemsys_l in all_chemsys: - chemsys = "-".join(sorted(chemsys_l)) - chemsys_wo = "-".join(sorted(set(chemsys_l) - {self.working_ion})) - chemsys_query = { - "$and": [{"chemsys": {"$in": [chemsys_wo, chemsys]}}, self.query.copy()] - } - self.logger.debug(f"QUERY: {chemsys_query}") - all_mats_in_chemsys = list( - self.materials.query( - criteria=chemsys_query, - properties=MAT_PROPS + [self.materials.last_updated_field], - ) - ) - self.logger.debug( - f"Found {len(all_mats_in_chemsys)} materials in {chemsys_wo}" - ) - if self.check_newer: - all_target_docs = list( - self.sgroups.query( - criteria={"chemsys": chemsys}, - properties=[ - "group_id", - self.sgroups.last_updated_field, - "material_ids", - ], - ) - ) - self.logger.debug( - f"Found {len(all_target_docs)} Grouped documents in {chemsys_wo}" 
- ) - - mat_times = [ - mat_doc[self.materials.last_updated_field] - for mat_doc in all_mats_in_chemsys - ] - max_mat_time = max(mat_times, default=datetime.min) - self.logger.debug( - f"The newest material doc was generated at {max_mat_time}." - ) - - target_times = [ - g_doc[self.materials.last_updated_field] - for g_doc in all_target_docs - ] - min_target_time = min(target_times, default=datetime.max) - self.logger.debug( - f"The newest GROUP doc was generated at {min_target_time}." - ) - - mat_ids = set( - [mat_doc["material_id"] for mat_doc in all_mats_in_chemsys] - ) - - # If any material id is missing or if any material id has been updated - target_ids = set() - for g_doc in all_target_docs: - target_ids |= set(g_doc["material_ids"]) - - self.logger.debug( - f"There are {len(mat_ids)} material ids in source database vs {len(target_ids)} in target database." - ) - if mat_ids == target_ids and max_mat_time < min_target_time: - self.logger.info(f"Skipping chemsys {chemsys}.") - yield None - elif len(target_ids) == 0: - self.logger.info( - f"No documents in chemsys {chemsys} in the target database." - ) - yield {"chemsys": chemsys, "materials": all_mats_in_chemsys} - else: - self.logger.info( - f"Nuking all {len(target_ids)} documents in chemsys {chemsys} in the target database." 
- ) - self._remove_targets(list(target_ids)) - yield {"chemsys": chemsys, "materials": all_mats_in_chemsys} - else: - yield {"chemsys": chemsys, "materials": all_mats_in_chemsys} - - def update_targets(self, items: list): - items = list(filter(None, chain.from_iterable(items))) - if len(items) > 0: - self.logger.info("Updating {} sgroups documents".format(len(items))) - for struct_group_dict in items: - struct_group_dict[self.sgroups.last_updated_field] = datetime.utcnow() - self.sgroups.update(docs=items, key=["group_id"]) - else: - self.logger.info("No items to update") - - def _entry_from_mat_doc(self, mdoc): - # Note since we are just structure grouping we don't need to be careful with energy or correction - # All of the energy analysis is left to other builders - entries = [ - ComputedStructureEntry.from_dict(v) for v in mdoc["entries"].values() - ] - if len(entries) == 1: - return entries[0] - else: - if "GGA+U" in mdoc["entries"].keys(): - return ComputedStructureEntry.from_dict(mdoc["entries"]["GGA+U"]) - elif "GGA" in mdoc["entries"].keys(): - return ComputedStructureEntry.from_dict(mdoc["entries"]["GGA"]) - else: - return None - - def process_item(self, item: Any) -> Any: - if item is None: - return None - entries = [*map(self._entry_from_mat_doc, item["materials"])] - compatibility = MaterialsProject2020Compatibility() - processed_entries = compatibility.process_entries(entries=entries) - s_groups = StructureGroupDoc.from_ungrouped_structure_entries( - entries=processed_entries, - ignored_specie=self.working_ion, - ltol=self.ltol, - stol=self.stol, - angle_tol=self.angle_tol, - ) - return [sg.model_dump() for sg in s_groups] - - def _remove_targets(self, rm_ids): - self.sgroups.remove_docs({"material_ids": {"$in": rm_ids}}) - - -class InsertionElectrodeBuilder(Builder): - def __init__( - self, - grouped_materials: MongoStore, - thermo: MongoStore, - insertion_electrode: MongoStore, - query: dict | None = None, - strip_structures: bool = False, - 
**kwargs, - ): - self.grouped_materials = grouped_materials - self.insertion_electrode = insertion_electrode - self.thermo = thermo - self.query = query if query else {} - self.strip_structures = strip_structures - - super().__init__( - sources=[self.grouped_materials, self.thermo], - targets=[self.insertion_electrode], - **kwargs, - ) - - def prechunk(self, number_splits: int) -> Iterator[dict]: - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - keys = self.grouped_materials.distinct(self.grouped_materials.key, criteria=q) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.grouped_materials.key: {"$in": list(split)}}} - - def get_items(self): - """ - Get items - """ - - @lru_cache(1000) - def get_working_ion_entry(working_ion): - with self.thermo as store: - working_ion_docs = [*store.query({"chemsys": working_ion})] - best_wion = min(working_ion_docs, key=lambda x: x["energy_per_atom"]) - return best_wion - - def get_thermo_docs(mat_ids): - self.logger.debug( - f"Looking for {len(mat_ids)} material_id in the Thermo DB." - ) - self.thermo.connect() - thermo_docs = list( - self.thermo.query( - {"$and": [{"material_id": {"$in": mat_ids}}]}, - properties=[ - "material_id", - "_sbxn", - "thermo", - "entries", - "energy_type", - "energy_above_hull", - ], - ) - ) - - self.logger.debug(f"Found for {len(thermo_docs)} Thermo Documents.") - if len(thermo_docs) != len(mat_ids): - missing_ids = set(mat_ids) - set( - [t_["material_id"] for t_ in thermo_docs] - ) - self.logger.warn( - f"The following ids are missing from the entries in thermo {missing_ids}.\n" - "The is likely due to the fact that a calculation other than GGA or GGA+U was " - "validated for the materials builder." 
- ) - return None - - # if len(item["ignored_species"]) != 1: - # raise ValueError( - # "Insertion electrode can only be defined for one working ion species" - # ) - - return thermo_docs - # return { - # "group_id": item["group_id"], - # "working_ion_doc": working_ion_doc, - # "working_ion": item["ignored_species"][0], - # "thermo_docs": thermo_docs, - # } - - q_ = {"$and": [self.query, {"has_distinct_compositions": True}]} - self.total = self.grouped_materials.count(q_) - for group_doc in self.grouped_materials.query(q_): - working_ion_doc = get_working_ion_entry(group_doc["ignored_specie"]) - thermo_docs = get_thermo_docs(group_doc["material_ids"]) - if thermo_docs: - yield { - "group_id": group_doc["group_id"], - "working_ion_doc": working_ion_doc, - "working_ion": group_doc["ignored_specie"], - "thermo_docs": thermo_docs, - } - else: - yield None - - def process_item(self, item) -> dict: - """ - - Add volume information to each entry to create the insertion electrode document - - Add the host structure - """ - if item is None: - return None # type: ignore - self.logger.debug( - f"Working on {item['group_id']} with {len(item['thermo_docs'])}" - ) - - entries = [ - tdoc_["entries"][tdoc_["energy_type"]] for tdoc_ in item["thermo_docs"] - ] - - entries = list(map(ComputedStructureEntry.from_dict, entries)) - - working_ion_entry = ComputedEntry.from_dict( - item["working_ion_doc"]["entries"][item["working_ion_doc"]["energy_type"]] - ) - - decomp_energies = { - d_["material_id"]: d_["energy_above_hull"] for d_ in item["thermo_docs"] - } - - for ient in entries: - ient.data["volume"] = ient.structure.volume - ient.data["decomposition_energy"] = decomp_energies[ - ient.data["material_id"] - ] - - ie = InsertionElectrodeDoc.from_entries( - grouped_entries=entries, - working_ion_entry=working_ion_entry, - battery_id=item["group_id"], - strip_structures=self.strip_structures, - ) - if ie is None: - return None # type: ignore - # {"failed_reason": "unable to create 
InsertionElectrode document"} - return jsanitize(ie.model_dump()) - - def update_targets(self, items: list): - items = list(filter(None, items)) - if len(items) > 0: - self.logger.info("Updating {} battery documents".format(len(items))) - for struct_group_dict in items: - struct_group_dict[self.grouped_materials.last_updated_field] = ( - datetime.utcnow() - ) - self.insertion_electrode.update(docs=items, key=["battery_id"]) - else: - self.logger.info("No items to update") - - -class ConversionElectrodeBuilder(Builder): - def __init__( - self, - phase_diagram_store: MongoStore, - conversion_electrode_store: MongoStore, - working_ion: str, - thermo_type: str, - query: dict | None = None, - **kwargs, - ): - self.phase_diagram_store = phase_diagram_store - self.conversion_electrode_store = conversion_electrode_store - self.working_ion = working_ion - self.thermo_type = thermo_type - self.query = query if query else {} - self.kwargs = kwargs - - self.phase_diagram_store.key = "phase_diagram_id" - self.conversion_electrode_store.key = "conversion_electrode_id" - - super().__init__( - sources=[self.phase_diagram_store], - targets=[self.conversion_electrode_store], - **kwargs, - ) - - def prechunk(self, number_splits: int) -> Iterator[dict]: - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - keys = self.phase_diagram_store.distinct( - self.phase_diagram_store.key, criteria=q - ) - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.phase_diagram_store.key: {"$in": list(split)}}} - - def get_items(self): - """ - Get items. Phase diagrams are filtered such that only PDs with chemical systems containing - the working ion and the specified "thermo_type", or functional, are chosen. 
- """ - - all_chemsys = self.phase_diagram_store.distinct("chemsys") - - chemsys_w_wion = [c for c in all_chemsys if self.working_ion in c] - - q = { - "$and": [ - dict(self.query), - {"thermo_type": self.thermo_type}, - {"chemsys": {"$in": chemsys_w_wion}}, - ] - } - - for phase_diagram_doc in self.phase_diagram_store.query(criteria=q): - yield phase_diagram_doc - - def process_item(self, item) -> dict: - """ - - For each phase diagram doc, find all possible conversion electrodes and create conversion electrode docs - """ - # To work around "el_refs" serialization issue (#576) - _pd = PhaseDiagram.from_dict(item["phase_diagram"]) - _entries = _pd.all_entries - pd = PhaseDiagram(entries=_entries) - - most_wi = defaultdict(lambda: (-1, None)) # type: dict - n_elements = pd.dim - # Only using entries on convex hull for now - for entry in pd.stable_entries: - if len(entry.composition.elements) != n_elements: - continue - composition_dict = entry.composition.as_dict() - composition_dict.pop(self.working_ion) - composition_without_wi = Composition.from_dict(composition_dict) - red_form, num_form = composition_without_wi.get_reduced_formula_and_factor() - n_wi = entry.composition.get_el_amt_dict()[self.working_ion] - most_wi[red_form] = max( - most_wi[red_form], (n_wi / num_form, entry.composition) - ) - - new_docs = [] - unique_reaction_compositions = set() - reaction_compositions = [] - for k, v in most_wi.items(): - if v[1] is not None: - # Get lowest material_id with matching composition - material_ids = [ - ( - lambda x: ( - x.data["material_id"] # type: ignore[attr-defined] - if x.composition.reduced_formula == v[1].reduced_formula - else None - ) - )(e) - for e in pd.entries - ] - material_ids = list(filter(None, material_ids)) - lowest_id = min(material_ids, key=_get_id_lexi) # type: ignore[arg-type] - conversion_electrode_doc = ( - ConversionElectrodeDoc.from_composition_and_pd( - comp=v[1], - pd=pd, - working_ion_symbol=self.working_ion, - 
battery_id=f"{lowest_id}_{self.working_ion}", - thermo_type=self.thermo_type, - ) - ) - # Get reaction entry_ids - comps = set() - for c in conversion_electrode_doc.reaction["reactants"].keys(): - comps.add(c) - unique_reaction_compositions.add(c) - for c in conversion_electrode_doc.reaction["products"].keys(): - comps.add(c) - unique_reaction_compositions.add(c) - reaction_compositions.append(comps) - new_docs.append(jsanitize(conversion_electrode_doc.model_dump())) - - entry_id_mapping = {} - for c in unique_reaction_compositions: - relevant_entry_data = [] - for e in pd.entries: - if e.composition == Composition(c): - relevant_entry_data.append((e.energy_per_atom, e.entry_id)) # type: ignore[attr-defined] - relevant_entry_data.sort(key=lambda x: x[0]) - entry_id_mapping[c] = relevant_entry_data[0][1] - - for i, comps in enumerate(reaction_compositions): - mapping = {} - for c in comps: - mapping[c] = entry_id_mapping[c] - new_docs[i]["formula_id_mapping"] = mapping - - return new_docs # type: ignore - - def update_targets(self, items: list): - combined_items = [] - for _items in items: - _items = list(filter(None, _items)) - combined_items.extend(_items) - - if len(combined_items) > 0: - self.logger.info( - "Updating {} conversion battery documents".format(len(combined_items)) - ) - self.conversion_electrode_store.update( - docs=combined_items, key=["battery_id", "thermo_type"] - ) - else: - self.logger.info("No items to update") diff --git a/emmet-builders/emmet/builders/materials/electronic_structure.py b/emmet-builders/emmet/builders/materials/electronic_structure.py index 522abf9a61..25c77a5738 100644 --- a/emmet-builders/emmet/builders/materials/electronic_structure.py +++ b/emmet-builders/emmet/builders/materials/electronic_structure.py @@ -1,766 +1,295 @@ -import itertools -import re -import warnings -from collections import defaultdict -from math import ceil - -import boto3 -import numpy as np -from botocore.handlers import disable_signing -from 
maggma.builders import Builder -from maggma.utils import grouper -from pymatgen.analysis.magnetism.analyzer import CollinearMagneticStructureAnalyzer -from pymatgen.analysis.structure_matcher import StructureMatcher -from pymatgen.core import Structure -from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine -from pymatgen.electronic_structure.core import Spin -from pymatgen.electronic_structure.dos import CompleteDos -from pymatgen.io.vasp.sets import MPStaticSet -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - -from emmet.builders.utils import query_open_data -from emmet.core.band_theory import get_path_from_bandstructure, obtain_path_type +from dataclasses import dataclass +from datetime import datetime +from enum import Enum, auto +from functools import update_wrapper + +from pydantic import BaseModel +from pymatgen.analysis.magnetism.analyzer import Ordering + +from emmet.builders.utils import filter_map from emmet.core.electronic_structure import ElectronicStructureDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.ElectronicStructureBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class ElectronicStructureBuilder(Builder): - def __init__( - self, - tasks, - materials, - electronic_structure, - bandstructure_fs, - dos_fs, - chunk_size=10, - query=None, - **kwargs, - ): - """ - Creates an electronic structure collection from a tasks collection, - the associated band structures and density of states file store collections, - and the materials collection. - - Individual bandstructures for each of the three conventions are generated. 
- - tasks (Store): Store of task documents - materials (Store): Store of materials documents - electronic_structure (Store): Store of electronic structure summary data documents - bandstructure_fs (Store, str): Store of bandstructures, or S3 URL string with prefix - (e.g. s3://materialsproject-parsed/bandstructures). - dos_fs (Store, str): Store of DOS, or S3 URL string with bucket and prefix - (e.g. s3://materialsproject-parsed/dos). - chunk_size (int): Chunk size to use for processing. Defaults to 10. - query (dict): Dictionary to limit materials to be analyzed - """ - - self.tasks = tasks - self.materials = materials - self.electronic_structure = electronic_structure - self.bandstructure_fs = bandstructure_fs - self.dos_fs = dos_fs - self.chunk_size = chunk_size - self.query = query if query else {} - - self._s3_resource = None - - sources = [tasks, materials] - - fs_stores = [bandstructure_fs, dos_fs] - - for store in fs_stores: - if isinstance(store, str): - if not re.match("^s3://.*", store): - raise ValueError( - "Please provide an S3 URL " - "in the format s3://{bucket_name}/{prefix}" - ) - - if self._s3_resource is None: - self._s3_resource = boto3.resource("s3") - self._s3_resource.meta.client.meta.events.register( - "choose-signer.s3.*", disable_signing - ) - - else: - sources.append(store) - - super().__init__( - sources=sources, - targets=[electronic_structure], - chunk_size=chunk_size, - **kwargs, - ) +from emmet.core.material import PropertyOrigin +from emmet.core.types.electronic_structure import BSShim, DosShim +from emmet.core.types.pymatgen_types.structure_adapter import StructureType +from emmet.core.types.typing import IdentifierType - def prechunk(self, number_splits: int): # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - keys = self.electronic_structure.newer_in( - self.materials, criteria=q, exhaustive=True - ) +class InputMeta(BaseModel): + # structure metadata + deprecated: bool 
= False + material_id: IdentifierType | None = None + meta_structure: StructureType + origins: list[PropertyOrigin] = [] + warnings: list[str] = [] - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - def get_items(self): - """ - Gets all items to process +class StructureInputs(InputMeta): + # summary electronic structure data from VASP outputs for task doc + band_gap: float + cbm: float | None = None + vbm: float | None = None + efermi: float | None = None + is_gap_direct: bool + is_metal: bool + magnetic_ordering: Ordering - Returns: - generator or list relevant tasks and materials to process - """ - self.logger.info("Electronic Structure Builder Started") +class StructuresShim(InputMeta): + # map of structures with task_ids -> used in post doc build checks + structures: dict[IdentifierType, StructureType] - q = dict(self.query) - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - es_ids = self.electronic_structure.distinct(self.electronic_structure.key) +class BSInputs(StructuresShim): + bandstructures: BSShim - mats_set = set( - self.electronic_structure.newer_in( - target=self.materials, criteria=q, exhaustive=True - ) - ) | (set(mat_ids) - set(es_ids)) - mats = [mat for mat in mats_set] +class DosInputs(StructuresShim): + dos: DosShim + is_gap_direct: bool - self.logger.info( - "Processing {} materials for electronic structure".format(len(mats)) - ) - self.total = len(mats) +class BSDosInputs(DosInputs, BSInputs): ... - for mat in mats: - mat = self._update_materials_doc(mat) - yield mat - def process_item(self, mat): - """ - Process the band structures and dos data. 
+class Variant(Enum): + STRUCTURE = auto() + BS = auto() + DOS = auto() + BS_DOS = auto() - Args: - mat (dict): material document - Returns: - (dict): electronic_structure document - """ +InputData = list[StructureInputs] | list[BSInputs] | list[DosInputs] | list[BSDosInputs] +"""Tagged union for valid input types for build_electronic_structure_docs.""" - structure = Structure.from_dict(mat["structure"]) - self.logger.info("Processing: {}".format(mat[self.materials.key])) +@dataclass +class ESBuilderInput: + """ + Container for electronic structure builder inputs that pairs + a Variant tag with the corresponding input data. - dos = None - bs = {} - structures = {} + The variant field determines which construction path + build_electronic_structure_docs will use to produce + ElectronicStructureDoc instances. - for bs_type, bs_entry in mat["bandstructure"].items(): - if bs_entry.get("object", None) is not None: - bs[bs_type] = ( - { - bs_entry["task_id"]: BandStructureSymmLine.from_dict( - bs_entry["object"] - ) - } - if bs_entry - else None - ) + The data field holds a list of the appropriate input model + (StructureInputs, BSInputs, DosInputs, or BSDosInputs) matching + the chosen variant. Callers are responsible for populating this + list within their own data pipeline context; helper functions + such as ``obtain_blessed_dos`` and ``obtain_blessed_bs`` can assist + in selecting the best candidate calculations for inclusion. + """ - structures[bs_entry["task_id"]] = bs_entry["output_structure"] + variant: Variant + data: InputData - if mat["dos"]: - if mat["dos"]["object"] is not None: - self.logger.info("Processing density of states") - dos = { - mat["dos"]["task_id"]: CompleteDos.from_dict(mat["dos"]["object"]) - } - structures[mat["dos"]["task_id"]] = mat["dos"]["output_structure"] +def variant_dispatch(func): + """ + Slight remix to functools.singledispatch to perform dynamic + dispatch based on the type of an enum variant for arg, rather + than arg itself. 
- if bs: - self.logger.info( - "Processing band structure types: {}".format( - [bs_type for bs_type, bs_entry in bs.items() if bs_entry] - ) - ) + Only usable with objects with a ``variant`` enum attr. See: ESBuilderInput + """ + registry = {} - # Default summary data - d = dict( - material_id=mat[self.materials.key], - deprecated=mat["deprecated"], - task_id=mat["other"]["task_id"], - meta_structure=structure, - band_gap=mat["other"]["band_gap"], - cbm=mat["other"]["cbm"], - vbm=mat["other"]["vbm"], - efermi=mat["other"]["efermi"], - is_gap_direct=mat["other"]["is_gap_direct"], - is_metal=mat["other"]["is_metal"], - magnetic_ordering=mat["other"]["magnetic_ordering"], - origins=mat["origins"], - warnings=[], - ) + def register(variant_value): + def decorator(f): + registry[variant_value] = f + return f - # Eigenvalue band property checks - eig_values = mat["other"].get("eigenvalue_band_properties", None) - - if eig_values is not None: - if not np.isclose( - mat["other"]["band_gap"], eig_values["bandgap"], atol=0.2, rtol=0.0 - ): - d["warnings"].append( - "Regular parsed band gap and band gap from eigenvalue_band_properties do not agree. " - "Using data from eigenvalue_band_properties where appropriate." 
- ) - - d["band_gap"] = eig_values["bandgap"] - d["cbm"] = eig_values["cbm"] - d["vbm"] = eig_values["vbm"] - d["is_gap_direct"] = eig_values["is_gap_direct"] - d["is_metal"] = ( - True if np.isclose(d["band_gap"], 0.0, atol=0.01, rtol=0) else False - ) - - if dos is None: - doc = ElectronicStructureDoc.from_structure(**d) - - else: - try: - doc = ElectronicStructureDoc.from_bsdos( - material_id=mat[self.materials.key], - structures=structures, - dos=dos, - is_gap_direct=d["is_gap_direct"], - is_metal=d["is_metal"], - deprecated=d["deprecated"], - origins=d["origins"], - **bs, - ) - doc = self._bsdos_checks(doc, dos[mat["dos"]["task_id"]], structures) - - except Exception: - d["warnings"].append( - "Band structure and/or data exists but an error occured while processing." - ) - doc = ElectronicStructureDoc.from_structure(**d) - - # Magnetic ordering check - mag_orderings = {} - if doc.bandstructure is not None: - mag_orderings.update( - { - bs_summary.task_id: bs_summary.magnetic_ordering - for bs_type, bs_summary in doc.bandstructure - if bs_summary is not None - } - ) + return decorator - if doc.dos is not None: - dos_dict = doc.dos.model_dump() - mag_orderings.update( - {dos_dict["total"][Spin.up]["task_id"]: dos_dict["magnetic_ordering"]} - ) + def wrapper(inputs, **kwargs): + variant = inputs.variant + if variant not in registry: + raise NotImplementedError(f"No handler registered for {variant!r}") + return registry[variant](inputs, **kwargs) - for task_id, ordering in mag_orderings.items(): - if doc.magnetic_ordering != ordering: - doc.warnings.append( - f"Summary data magnetic ordering does not agree with the ordering from {task_id}" - ) - - # LMAXMIX check, VASP default is 2 - expected_lmaxmix = MPStaticSet(structure).incar.get("LMAXMIX", 2) - if mat["dos"] and mat["dos"]["lmaxmix"] != expected_lmaxmix: - doc.warnings.append( - "An incorrect calculation parameter may lead to errors in the band gap of " - f"0.1-0.2 eV (LMAXIX is {mat['dos']['lmaxmix']} and 
should be {expected_lmaxmix} for " - f"{mat['dos']['task_id']}). A correction calculation is planned." - ) + wrapper.register = register + update_wrapper(wrapper, func) + return wrapper - for bs_type, bs_entry in mat["bandstructure"].items(): - if bs_entry["lmaxmix"] != expected_lmaxmix: - doc.warnings.append( - "An incorrect calculation parameter may lead to errors in the band gap of " - f"0.1-0.2 eV (LMAXIX is {bs_entry['lmaxmix']} and should be {expected_lmaxmix} for " - f"{bs_entry['task_id']}). A correction calculation is planned." - ) - - return doc.model_dump() - - def update_targets(self, items): - """ - Inserts electronic structure documents into the electronic_structure collection - - Args: - items ([dict]): A list of ElectronicStructureDoc dictionaries to update - """ - - items = list(filter(None, items)) - - if len(items) > 0: - self.logger.info("Updating {} electronic structure docs".format(len(items))) - self.electronic_structure.update(docs=jsanitize(items, allow_bson=True)) - else: - self.logger.info("No electronic structure docs to update") - - def _bsdos_checks(self, doc, dos, structures): - # Band gap difference check for uniform and line-mode calculations - bgap_diff = [] - for bs_type, bs_summary in doc.bandstructure: - if bs_summary is not None: - bgap_diff.append(doc.band_gap - bs_summary.band_gap) - - if dos is not None: - bgap_diff.append(doc.band_gap - dos.get_gap()) - - if any(abs(gap) > 0.25 for gap in bgap_diff): - if doc.warnings is None: - doc.warnings = [] - doc.warnings.append( - "Absolute difference between blessed band gap and at least one " - "line-mode or uniform calculation band gap is larger than 0.25 eV." - ) - # Line-mode and uniform structure primitive checks +@variant_dispatch +def build_electronic_structure_docs( + inputs: ESBuilderInput, **kwargs +) -> list[ElectronicStructureDoc]: + """ + Generate electronic structure documents from tagged input data. 
- pair_list = [] - for task_id, struct in structures.items(): - pair_list.append((task_id, struct)) + Dispatches on the variant field of the provided ESBuilderInput to + construct ElectronicStructureDoc instances via the appropriate + factory method (from_structure, from_bs, from_dos, or from_bsdos). - struct_prim = SpacegroupAnalyzer(struct).get_primitive_standard_structure( - international_monoclinic=False - ) + Caller is responsible for creating ESBuilderInput instances + within their data pipeline context. + + Args: + inputs: An ESBuilderInput whose variant selects the + construction path and whose data contains the + corresponding list of input documents to process. + + Returns: + list[ElectronicStructureDoc] + """ - if not np.allclose( - struct.lattice.matrix, struct_prim.lattice.matrix, atol=1e-3 - ): - if doc.warnings is None: - doc.warnings = [] - - if np.isclose(struct_prim.volume, struct.volume, atol=5, rtol=0): - doc.warnings.append( - f"The input structure for {task_id} is primitive but may not exactly match the " - f"standard primitive setting." 
- ) - else: - doc.warnings.append( - f"The input structure for {task_id} does not match the expected standard primitive" - ) - - # Check line-mode and uniform for same structure - sm = StructureMatcher() - for pair in itertools.combinations(pair_list, 2): - if not sm.fit(pair[0][1], pair[1][1]): - if doc.warnings is None: - doc.warnings = [] - - doc.warnings.append( - f"The input structures between bandstructure calculations {pair[0][0]} and {pair[1][0]} " - f"are not equivalent" - ) - - return doc - - def _update_materials_doc(self, mat_id): - # find bs type for each task in task_type and store each different bs object - - mat = self.materials.query_one( - properties=[ - self.materials.key, - "structure", - "inputs", - "task_types", + +@build_electronic_structure_docs.register(Variant.STRUCTURE) +def _(inputs: ESBuilderInput, **kwargs) -> list[ElectronicStructureDoc]: + return list( + filter_map( + ElectronicStructureDoc.from_structure, + inputs.data, + work_keys=[ "deprecated", - self.materials.last_updated_field, + "material_id", + "meta_structure", + "origins", + "warnings", + "band_gap", + "cbm", + "vbm", + "efermi", + "is_gap_direct", + "is_metal", + "magnetic_ordering", ], - criteria={self.materials.key: mat_id}, + **kwargs, ) + ) + + +@build_electronic_structure_docs.register(Variant.BS) +def _(inputs: ESBuilderInput, **kwargs) -> list[ElectronicStructureDoc]: + return list( + filter_map( + ElectronicStructureDoc.from_bs, + inputs.data, + work_keys=[ + "bandstructures", + "origins", + "structures", + # PropertyDoc.from_structure(...) 
kwargs + "deprecated", + "material_id", + "meta_structure", + ], + **kwargs, + ) + ) + - mat["dos"] = {} - mat["bandstructure"] = defaultdict(dict) - mat["other"] = {} - - bs_calcs = defaultdict(list) - dos_calcs = [] - other_calcs = [] - - for task_id in mat["task_types"].keys(): - # Handle all line-mode tasks - if "NSCF Line" in mat["task_types"][task_id]: - bs_type = None - - task_query = self.tasks.query_one( - properties=[ - "calcs_reversed", - "last_updated", - "input.is_hubbard", - "input.incar", - "orig_inputs.kpoints", - "input.parameters", - "output.structure", - ], - criteria={"task_id": str(task_id)}, - ) - - fs_id = str( - task_query["calcs_reversed"][0].get("bandstructure_fs_id", None) - ) - - if fs_id is not None: - structure = Structure.from_dict(task_query["output"]["structure"]) - - kpoints = task_query["orig_inputs"]["kpoints"] - - labels_dict = { - label: point - for label, point in zip(kpoints["labels"], kpoints["kpoints"]) - if label is not None - } - - try: - bs_type = next( - obtain_path_type( - labels_dict, - structure, - [label for label in kpoints["labels"] if label], - ) - ) - except Exception: - bs_type = None - - if bs_type is None: - if isinstance(self.bandstructure_fs, str): - _, _, bucket, prefix = self.bandstructure_fs.strip( - "/" - ).split("/") - - bs_dict = query_open_data( - bucket, - prefix, - task_id, - monty_decode=False, - s3_resource=self._s3_resource, - ) - else: - bs_dict = self.bandstructure_fs.query_one( - {self.bandstructure_fs.key: str(task_id)} - ) - - if bs_dict is not None: - bs = BandStructureSymmLine.from_dict(bs_dict["data"]) - - labels_dict = { - label: kpoint.frac_coords - for label, kpoint in bs.labels_dict.items() - } - - try: - bs_type = next( - obtain_path_type( - labels_dict, - bs.structure, - get_path_from_bandstructure(bs), - ) - ) - except Exception: - bs_type = None - - # Clear bs data - bs = None - bs_dict = None - - is_hubbard = task_query["input"]["is_hubbard"] - lmaxmix = 
task_query["input"]["incar"].get( - "LMAXMIX", 2 - ) # VASP default is 2, alternatively could project `parameters` - nkpoints = task_query["orig_inputs"]["kpoints"]["nkpoints"] - lu_dt = task_query["last_updated"] - - if bs_type is not None: - bs_calcs[bs_type].append( - { - "fs_id": fs_id, - "task_id": task_id, - "is_hubbard": int(is_hubbard), - "lmaxmix": lmaxmix, - "nkpoints": int(nkpoints), - "updated_on": lu_dt, - "output_structure": structure, - "labels_dict": labels_dict, - } - ) - - # Handle uniform tasks - if "NSCF Uniform" in mat["task_types"][task_id]: - task_query = self.tasks.query_one( - properties=[ - "calcs_reversed", - "last_updated", - "input.is_hubbard", - "input.incar", - "orig_inputs.kpoints", - "input.parameters", - "output.structure", - ], - criteria={"task_id": str(task_id)}, - ) - - fs_id = str(task_query["calcs_reversed"][0].get("dos_fs_id", None)) - - if fs_id is not None: - lmaxmix = task_query["input"]["incar"].get( - "LMAXMIX", 2 - ) # VASP default is 2, alternatively could project `parameters` - - is_hubbard = task_query["input"]["is_hubbard"] - - structure = Structure.from_dict(task_query["output"]["structure"]) - - if ( - task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Monkhorst" - or task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Gamma" - ): - nkpoints = np.prod( - task_query["orig_inputs"]["kpoints"]["kpoints"][0], axis=0 - ) - - else: - nkpoints = task_query["orig_inputs"]["kpoints"]["nkpoints"] - - nedos = task_query["input"]["parameters"]["NEDOS"] - lu_dt = task_query["last_updated"] - - dos_calcs.append( - { - "fs_id": fs_id, - "task_id": task_id, - "is_hubbard": int(is_hubbard), - "lmaxmix": lmaxmix, - "nkpoints": int(nkpoints), - "nedos": int(nedos), - "updated_on": lu_dt, - "output_structure": structure, - } - ) - - # Handle static and structure opt tasks - if "Static" or "Structure Optimization" in mat["task_types"][task_id]: - task_query = self.tasks.query_one( - properties=[ - 
"last_updated", - "input.is_hubbard", - "orig_inputs.kpoints", - "calcs_reversed", - "output.structure", - ], - criteria={"task_id": str(task_id)}, - ) - - structure = Structure.from_dict(task_query["output"]["structure"]) - - other_mag_ordering = CollinearMagneticStructureAnalyzer( - structure - ).ordering - - is_hubbard = task_query["input"]["is_hubbard"] - - last_calc = task_query["calcs_reversed"][-1] - - if ( - last_calc["input"]["kpoints"]["generation_style"] == "Monkhorst" - or last_calc["input"]["kpoints"]["generation_style"] == "Gamma" - ): - nkpoints = np.prod( - last_calc["input"]["kpoints"]["kpoints"][0], axis=0 - ) - else: - nkpoints = last_calc["input"]["kpoints"]["nkpoints"] - - lu_dt = task_query["last_updated"] - - other_calcs.append( - { - "is_static": ( - True if "Static" in mat["task_types"][task_id] else False - ), - "task_id": task_id, - "is_hubbard": int(is_hubbard), - "nkpoints": int(nkpoints), - "magnetic_ordering": other_mag_ordering, - "updated_on": lu_dt, - "calcs_reversed": task_query["calcs_reversed"], - } - ) - - updated_materials_doc = self._obtain_blessed_calculations( - mat, bs_calcs, dos_calcs, other_calcs +@build_electronic_structure_docs.register(Variant.DOS) +def _(inputs: ESBuilderInput, **kwargs) -> list[ElectronicStructureDoc]: + return list( + filter_map( + ElectronicStructureDoc.from_dos, + inputs.data, + work_keys=[ + "dos", + "is_gap_direct", + "origins", + "structures", + # PropertyDoc.from_structure(...) kwargs + "deprecated", + "material_id", + "meta_structure", + ], + **kwargs, + ) + ) + + +@build_electronic_structure_docs.register(Variant.BS_DOS) +def _(inputs: ESBuilderInput, **kwargs) -> list[ElectronicStructureDoc]: + return list( + filter_map( + ElectronicStructureDoc.from_bsdos, + inputs.data, + work_keys=[ + "bandstructures", + "dos", + "origins", + "structures", + # PropertyDoc.from_structure(...) 
kwargs + "deprecated", + "material_id", + "meta_structure", + ], + **kwargs, ) + ) - return updated_materials_doc - - def _obtain_blessed_calculations( - self, materials_doc, bs_calcs, dos_calcs, other_calcs - ): - bs_types = ["setyawan_curtarolo", "hinuma", "latimer_munro"] - - materials_doc["origins"] = [] - - for bs_type in bs_types: - # select "blessed" bs of each type - if bs_calcs[bs_type]: - sorted_bs_data = sorted( - bs_calcs[bs_type], - key=lambda entry: ( - entry["is_hubbard"], - entry["nkpoints"], - entry["updated_on"], - ), - reverse=True, - ) - - materials_doc["bandstructure"][bs_type]["task_id"] = sorted_bs_data[0][ - "task_id" - ] - - materials_doc["bandstructure"][bs_type]["lmaxmix"] = sorted_bs_data[0][ - "lmaxmix" - ] - if isinstance(self.bandstructure_fs, str): - _, _, bucket, prefix = self.bandstructure_fs.strip("/").split("/") - bs_obj = query_open_data( - bucket, - prefix, - sorted_bs_data[0]["task_id"], - monty_decode=False, - s3_resource=self._s3_resource, - ) - else: - bs_obj = self.bandstructure_fs.query_one( - criteria={"fs_id": sorted_bs_data[0]["fs_id"]} - ) - - materials_doc["bandstructure"][bs_type]["object"] = ( - bs_obj["data"] if bs_obj is not None else None - ) - - materials_doc["bandstructure"][bs_type]["output_structure"] = ( - sorted_bs_data[0]["output_structure"] - ) - - materials_doc["origins"].append( - { - "name": bs_type, - "task_id": sorted_bs_data[0]["task_id"], - "last_updated": sorted_bs_data[0]["updated_on"], - } - ) - - if dos_calcs: - sorted_dos_data = sorted( - dos_calcs, - key=lambda entry: ( - entry["is_hubbard"], - entry["nkpoints"], - entry["nedos"], - entry["updated_on"], - ), - reverse=True, - ) - materials_doc["dos"]["task_id"] = sorted_dos_data[0]["task_id"] - - materials_doc["dos"]["lmaxmix"] = sorted_dos_data[0]["lmaxmix"] - - if isinstance(self.bandstructure_fs, str): - _, _, bucket, prefix = self.dos_fs.strip("/").split("/") - dos_obj = query_open_data( - bucket, - prefix, - 
sorted_dos_data[0]["task_id"], - monty_decode=False, - s3_resource=self._s3_resource, - ) - else: - dos_obj = self.dos_fs.query_one( - criteria={"fs_id": sorted_dos_data[0]["fs_id"]} - ) - - materials_doc["dos"]["object"] = ( - dos_obj["data"] if dos_obj is not None else None - ) +# ----------------------------------------------------------------------------- +# Helper funcs + types +# ----------------------------------------------------------------------------- - materials_doc["dos"]["output_structure"] = sorted_dos_data[0][ - "output_structure" - ] - materials_doc["origins"].append( - { - "name": "dos", - "task_id": sorted_dos_data[0]["task_id"], - "last_updated": sorted_dos_data[0]["updated_on"], - } - ) +class BaseCalcInfo(BaseModel): + """Basic struct of metadata for use in sorting a list of candidate blessed calculations.""" - if other_calcs: - sorted_other_data = sorted( - other_calcs, - key=lambda entry: ( - entry["is_static"], - entry["is_hubbard"], - entry["nkpoints"], - entry["updated_on"], - ), - reverse=True, - ) + task_id: str + is_hubbard: bool | None + lmaxmix: int | None + nkpoints: int | None + last_updated: datetime - materials_doc["other"]["task_id"] = str(sorted_other_data[0]["task_id"]) - task_output_data = sorted_other_data[0]["calcs_reversed"][-1]["output"] - materials_doc["other"]["band_gap"] = task_output_data["bandgap"] - materials_doc["other"]["magnetic_ordering"] = sorted_other_data[0][ - "magnetic_ordering" - ] - materials_doc["other"]["last_updated"] = sorted_other_data[0]["updated_on"] +class DosCalc(BaseCalcInfo): + nedos: int | None - materials_doc["other"]["is_metal"] = ( - materials_doc["other"]["band_gap"] == 0.0 - ) - materials_doc["origins"].append( - { - "name": "electronic_structure", - "task_id": sorted_other_data[0]["task_id"], - "last_updated": sorted_other_data[0]["updated_on"], - } +class BSCalc(BaseCalcInfo): ... 
+ + +def obtain_blessed_dos(dos_calcs: list[DosCalc]) -> DosCalc: + """ + Yields best dos calc from list of dos calcs. + + Helpful for preparing ``ESBuilderInput`` for ``build_electronic_structure_docs`` + """ + sorted_dos_data = sorted( + dos_calcs, + key=lambda entry: ( + entry.is_hubbard, + entry.nkpoints, + entry.nedos, + entry.last_updated, + ), + reverse=True, + ) + return sorted_dos_data[0] + + +def obtain_blessed_bs(bs_calcs: dict[str, list[BSCalc]]) -> dict[str, BSCalc]: + """ + Yields map of best bs calc per path convention from map of lists of + bs calcs for each path convention. + + Helpful for preparing ``ESBuilderInput`` for ``build_electronic_structure_docs`` + """ + blessed_entries = {} + bs_types = ["setyawan_curtarolo", "hinuma", "latimer_munro"] + for bs_type in bs_types: + if bs_calcs.get(bs_type): + sorted_bs_data = sorted( + [entry for entry in bs_calcs[bs_type] if entry is not None], + key=lambda entry: ( + # Entries with any None sort last (False < True, reversed) + entry.is_hubbard is not None + and entry.nkpoints is not None + and entry.last_updated is not None, + entry.is_hubbard or False, + entry.nkpoints or 0, + entry.last_updated or datetime.min, + ), + reverse=True, ) - for prop in [ - "efermi", - "cbm", - "vbm", - "is_gap_direct", - "is_metal", - "eigenvalue_band_properties", - ]: - # First try other calcs_reversed entries if properties are not found in last - if prop not in task_output_data: - for calc in sorted_other_data[0]["calcs_reversed"]: - if calc["output"].get(prop, None) is not None: - materials_doc["other"][prop] = calc["output"][prop] - else: - materials_doc["other"][prop] = task_output_data[prop] - - return materials_doc + if sorted_bs_data: + blessed_entries[bs_type] = sorted_bs_data[0] + + return blessed_entries diff --git a/emmet-builders/emmet/builders/materials/magnetism.py b/emmet-builders/emmet/builders/materials/magnetism.py index 3ccfa0830a..9cfe4aac3e 100644 --- 
a/emmet-builders/emmet/builders/materials/magnetism.py +++ b/emmet-builders/emmet/builders/materials/magnetism.py @@ -1,171 +1,45 @@ -from __future__ import annotations - -import warnings -from math import ceil -from typing import TYPE_CHECKING - -from maggma.builders import Builder -from maggma.stores import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure - +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map from emmet.core.magnetism import MagnetismDoc -from emmet.core.mpid import AlphaID -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - -__author__ = "Shyam Dwaraknath , Matthew Horton " - - -warnings.warn( - f"The current version of {__name__}.MagneticBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class MagneticBuilder(Builder): - def __init__( - self, - materials: Store, - magnetism: Store, - tasks: Store, - query: dict | None = None, - **kwargs, - ): - """ - Creates a magnetism collection for materials - - Args: - materials (Store): Store of materials documents to match to - magnetism (Store): Store of magnetism properties - - """ - - self.materials = materials - self.magnetism = magnetism - self.tasks = tasks - self.query = query or {} - self.kwargs = kwargs - - self.materials.key = "material_id" - self.tasks.key = "task_id" - self.magnetism.key = "material_id" - - super().__init__(sources=[materials, tasks], targets=[magnetism], **kwargs) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - q.update({"deprecated": False}) - - keys = self.magnetism.newer_in(self.materials, criteria=q, exhaustive=True) - - N = 
ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def get_items(self): - """ - Gets all items to process - - Returns: - Generator or list relevant tasks and materials to process - """ - - self.logger.info("Magnetism Builder Started") - - q = dict(self.query) - - q.update({"deprecated": False}) - - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - mag_ids = self.magnetism.distinct(self.magnetism.key) - - mats_set = set( - self.magnetism.newer_in(target=self.materials, criteria=q, exhaustive=True) - ) | (set(mat_ids) - set(mag_ids)) - - mats = [mat for mat in mats_set] - - self.logger.info("Processing {} materials for magnetism data".format(len(mats))) - - self.total = len(mats) - - for mat in mats: - doc = self._get_processed_doc(mat) - - if doc is not None: - yield doc - else: - pass - - def process_item(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - origin_entry = { - "name": "magnetism", - "task_id": item["task_id"], - "last_updated": item["task_updated"], - } - - doc = MagnetismDoc.from_structure( - structure=structure, - material_id=mpid, - total_magnetization=item["total_magnetization"], - origins=[origin_entry], - deprecated=item["deprecated"], - last_updated=item["last_updated"], - ) - - return jsanitize(doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new magnetism docs into the magnetism collection - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} magnetism docs to update") - self.magnetism.update(docs) - else: - self.logger.info("No items to update") - - def _get_processed_doc(self, mat): - mat_doc = self.materials.query_one( - {self.materials.key: mat}, - [self.materials.key, "origins", "last_updated", "structure", "deprecated"], +from emmet.core.material import PropertyOrigin + + +class 
MagnetismBuilderInput(BaseBuilderInput): + total_magnetization: float + origins: list[PropertyOrigin] + + +def build_magnetism_docs( + input_documents: list[MagnetismBuilderInput], **kwargs +) -> list[MagnetismDoc]: + """ + Generate magnetism documents from input structures. + + Transforms a list of MagnetismBuilderInput documents containing + Pymatgen structures into corresponding MagnetismDoc instances by + analyzing the magnetic configuration of each structure. + + Caller is responsible for creating MagnetismBuilderInput instances + within their data pipeline context. + + Args: + input_documents: List of MagnetismBuilderInput documents to process. + + Returns: + list[MagnetismDoc] + """ + + return list( + filter_map( + MagnetismDoc.from_structure, + input_documents, + work_keys=[ + "deprecated", + "material_id", + "structure", + "origins", + "total_magnetization", + ], + **kwargs ) - - for origin in mat_doc["origins"]: - if origin["name"] == "structure": - task_id = str(AlphaID(origin["task_id"])) - - task_query = self.tasks.query_one( - properties=["last_updated", "calcs_reversed"], - criteria={self.tasks.key: task_id}, - ) - - task_updated = task_query["last_updated"] - total_magnetization = task_query["calcs_reversed"][-1]["output"]["outcar"][ - "total_magnetization" - ] - - mat_doc.update( - { - "task_id": task_id, - "total_magnetization": total_magnetization, - "task_updated": task_updated, - self.materials.key: mat_doc[self.materials.key], - } - ) - - return mat_doc + ) diff --git a/emmet-builders/emmet/builders/materials/optimade.py b/emmet-builders/emmet/builders/materials/optimade.py deleted file mode 100644 index 3e1d88f9eb..0000000000 --- a/emmet-builders/emmet/builders/materials/optimade.py +++ /dev/null @@ -1,179 +0,0 @@ -import warnings -from math import ceil -from typing import TYPE_CHECKING - -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure - -from 
emmet.core.optimade import OptimadeMaterialsDoc -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - -warnings.warn( - f"The current version of {__name__}.OptimadeMaterialsBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class OptimadeMaterialsBuilder(Builder): - def __init__( - self, - materials: Store, - thermo: Store, - optimade: Store, - query: dict | None = None, - **kwargs, - ): - """ - Creates Optimade compatible docs containing structure and thermo data for materials - - Args: - materials: Store of materials docs - thermo: Store of thermo docs - optimade: Store to update with optimade document - query : query on materials to limit search - """ - - self.materials = materials - self.thermo = thermo - self.optimade = optimade - self.query = query or {} - self.kwargs = kwargs - - # Enforce that we key on material_id - self.materials.key = "material_id" - self.thermo.key = "material_id" - self.optimade.key = "material_id" - - super().__init__(sources=[materials, thermo], targets=optimade, **kwargs) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - keys = self.optimade.newer_in(self.materials, criteria=q, exhaustive=True) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def get_items(self) -> Iterator: - """ - Gets all items to process - - Returns: - Generator or list of relevant materials - """ - - self.logger.info("Optimade Builder Started") - - q = dict(self.query) - - q.update({"deprecated": False}) - - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - opti_ids = 
self.optimade.distinct(self.optimade.key) - - mats_set = set( - self.optimade.newer_in(target=self.materials, criteria=q, exhaustive=True) - ) | (set(mat_ids) - set(opti_ids)) - - mats = [mat for mat in mats_set] - - self.total = len(mats) - - self.logger.info(f"Processing {self.total} items") - - for mat in mats: - doc = self._get_processed_doc(mat) - - if doc is not None: - yield doc - else: - pass - - def process_item(self, item): - mpid = item["mat_doc"]["material_id"] - structure = Structure.from_dict(item["mat_doc"]["structure"]) - last_updated_structure = item["mat_doc"]["last_updated"] - - # Functional names must be lowercase to adhere to optimade spec for querying attributes - thermo_calcs = {} - if item["thermo_doc"]: - for doc in item["thermo_doc"]: - thermo_calcs[doc["thermo_type"].lower()] = { - "thermo_id": doc["thermo_id"], - "energy_above_hull": doc["energy_above_hull"], - "formation_energy_per_atom": doc["formation_energy_per_atom"], - "last_updated_thermo": doc["last_updated"], - } - - optimade_doc = OptimadeMaterialsDoc.from_structure( - material_id=mpid, - structure=structure, - last_updated_structure=last_updated_structure, - thermo_calcs=thermo_calcs, - ) - - doc = jsanitize(optimade_doc.model_dump(), allow_bson=True) - - return doc - - def update_targets(self, items): - """ - Inserts the new optimade docs into the optimade collection - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} optimade docs to update") - self.optimade.update(docs) - else: - self.logger.info("No items to update") - - def _get_processed_doc(self, mat): - mat_doc = self.materials.query_one( - {self.materials.key: mat}, [self.materials.key, "last_updated", "structure"] - ) - - mat_doc.update( - { - self.materials.key: mat_doc[self.materials.key], - } - ) - - # Query thermo store for all docs matching material_id to catch - # multiple stability calculations for the same material_id - thermo_docs = self.thermo.query( - 
{self.thermo.key: mat}, - [ - self.thermo.key, - "thermo_type", - "thermo_id", - "energy_above_hull", - "formation_energy_per_atom", - "last_updated", - ], - ) - - thermo_list = [doc for doc in thermo_docs] - - if thermo_list: - for doc in thermo_list: - doc.update({self.thermo.key: doc[self.thermo.key]}) - - combined_doc = { - "mat_doc": mat_doc, - "thermo_doc": None if not thermo_list else thermo_list, - } - - return combined_doc diff --git a/emmet-builders/emmet/builders/materials/oxidation_states.py b/emmet-builders/emmet/builders/materials/oxidation_states.py index 1b92bfa430..3c4e027875 100644 --- a/emmet-builders/emmet/builders/materials/oxidation_states.py +++ b/emmet-builders/emmet/builders/materials/oxidation_states.py @@ -1,65 +1,37 @@ -import warnings - -from maggma.builders.map_builder import MapBuilder -from maggma.core import Store -from pymatgen.core import Structure - +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map from emmet.core.oxidation_states import OxidationStateDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.OxidationStatesBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class OxidationStatesBuilder(MapBuilder): - def __init__( - self, - materials: Store, - oxidation_states: Store, - query=None, - **kwargs, - ): - """ - Creates Oxidation State documents from materials - Args: - materials: Store of materials docs - oxidation_states: Store to update with oxidation state document - query : query on materials to limit search - """ - self.materials = materials - self.oxidation_states = oxidation_states - self.kwargs = kwargs - self.query = query or {} - # Enforce that we key on material_id - self.materials.key = "material_id" - self.oxidation_states.key = "material_id" - super().__init__( - source=materials, - target=oxidation_states, - projection=["structure", "deprecated", "builder_meta"], - query=query, - **kwargs, +def build_oxidation_states_docs( + input_documents: list[BaseBuilderInput], **kwargs +) -> list[OxidationStateDoc]: + """ + Generate oxidation state documents from input structures. + + Transforms a list of BaseBuilderInput documents containing + Pymatgen structures into corresponding OxidationStateDoc instances by + analyzing the oxidation states of each structure. + + Caller is responsible for creating BaseBuilderInput instances + within their data pipeline context. + + Args: + input_documents: List of BaseBuilderInput documents to process. 
+ + Returns: + list[OxidationStateDoc] + """ + return list( + filter_map( + OxidationStateDoc.from_structure, + input_documents, + work_keys=[ + "deprecated", + "material_id", + "structure", + "builder_meta", + ], + **kwargs ) - - def unary_function(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - deprecated = item["deprecated"] - builder_meta = item["builder_meta"] - - oxi_doc = OxidationStateDoc.from_structure( - structure=structure, - material_id=mpid, - deprecated=deprecated, - builder_meta=builder_meta, - ) - doc = jsanitize(oxi_doc.model_dump(), allow_bson=True) - - return doc + ) diff --git a/emmet-builders/emmet/builders/materials/piezoelectric.py b/emmet-builders/emmet/builders/materials/piezoelectric.py deleted file mode 100644 index f76a1124be..0000000000 --- a/emmet-builders/emmet/builders/materials/piezoelectric.py +++ /dev/null @@ -1,264 +0,0 @@ -import warnings -from math import ceil - -import numpy as np -from maggma.builders import Builder -from maggma.core import Store -from maggma.utils import grouper -from pymatgen.core.structure import Structure - -from emmet.core.mpid import AlphaID -from emmet.core.polar import PiezoelectricDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.PiezoelectricBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class PiezoelectricBuilder(Builder): - def __init__( - self, - materials: Store, - tasks: Store, - piezoelectric: Store, - query: dict | None = None, - **kwargs, - ): - self.materials = materials - self.tasks = tasks - self.piezoelectric = piezoelectric - self.query = query or {} - self.kwargs = kwargs - - self.materials.key = "material_id" - self.tasks.key = "task_id" - self.piezoelectric.key = "material_id" - - super().__init__( - sources=[materials, tasks], - targets=[piezoelectric], - **kwargs, - ) - - def prechunk(self, number_splits: int): # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - # Ensure no centrosymmetry - q.update( - { - "symmetry.point_group": { - "$nin": [ - "-1", - "2/m", - "mmm", - "4/m", - "4/mmm", - "-3", - "-3m", - "6/m", - "6/mmm", - "m-3", - "m-3m", - ] - } - } - ) - - keys = self.piezoelectric.newer_in(self.materials, criteria=q, exhaustive=True) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def get_items(self): - """ - Gets all items to process - - Returns: - generator or list relevant tasks and materials to process - """ - - self.logger.info("Piezoelectric Builder Started") - - q = dict(self.query) - - # Ensure no centrosymmetry - q.update( - { - "symmetry.point_group": { - "$nin": [ - "-1", - "2/m", - "mmm", - "4/m", - "4/mmm", - "-3", - "-3m", - "6/m", - "6/mmm", - "m-3", - "m-3m", - ] - } - } - ) - - mat_ids = self.materials.distinct(self.materials.key, criteria=q) - piezo_ids = self.piezoelectric.distinct(self.piezoelectric.key) - - mats_set = set( - self.piezoelectric.newer_in( - target=self.materials, criteria=q, exhaustive=True - ) - ) | (set(mat_ids) - set(piezo_ids)) - - mats = [mat for mat in mats_set] - - self.logger.info( - "Processing {} materials for 
piezoelectric data".format(len(mats)) - ) - - self.total = len(mats) - - for mat in mats: - doc = self._get_processed_doc(mat) - - if doc is not None: - yield doc - else: - pass - - def process_item(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - origin_entry = { - "name": "piezoelectric", - "task_id": item["task_id"], - "last_updated": item["task_updated"], - } - - doc = PiezoelectricDoc.from_ionic_and_electronic( - structure=structure, - material_id=mpid, - origins=[origin_entry], - deprecated=False, - ionic=item["piezo_ionic"], - electronic=item["piezo_static"], - last_updated=item["updated_on"], - ) - - return jsanitize(doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new dielectric docs into the dielectric collection - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} piezoelectric docs to update") - self.piezoelectric.update(docs) - else: - self.logger.info("No items to update") - - def _get_processed_doc(self, mat): - mat_doc = self.materials.query_one( - {self.materials.key: mat}, - [ - self.materials.key, - "structure", - "task_types", - "run_types", - "deprecated_tasks", - "last_updated", - ], - ) - - task_types = mat_doc["task_types"].items() - - potential_task_ids = [] - - for task_id, task_type in task_types: - if task_type == "DFPT Dielectric": - if task_id not in mat_doc["deprecated_tasks"]: - potential_task_ids.append(task_id) - - final_docs = [] - - for task_id in potential_task_ids: - task_query = self.tasks.query_one( - properties=[ - "last_updated", - "input.is_hubbard", - "orig_inputs.kpoints", - "orig_inputs.structure", - "input.parameters", - "input.structure", - "calcs_reversed", - "output.bandgap", - ], - criteria={self.tasks.key: str(AlphaID(task_id))}, - ) - if task_query["output"]["bandgap"] > 0: - try: - structure = task_query["input"]["structure"] - except KeyError: - structure = 
task_query["orig_inputs"]["structure"] - - is_hubbard = task_query["input"]["is_hubbard"] - - if ( - task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Monkhorst" - or task_query["orig_inputs"]["kpoints"]["generation_style"] - == "Gamma" - ): - nkpoints = np.prod( - task_query["orig_inputs"]["kpoints"]["kpoints"][0], axis=0 - ) - - else: - nkpoints = task_query["orig_inputs"]["kpoints"]["nkpoints"] - - lu_dt = mat_doc["last_updated"] - task_updated = task_query["last_updated"] - - if (cr := task_query.get("calcs_reversed", [])) and ( - outcar := cr[0].get("output", {}).get("outcar", {}) - ): - - final_docs.append( - { - "task_id": task_id, - "is_hubbard": int(is_hubbard), - "nkpoints": int(nkpoints), - "piezo_static": outcar.get("piezo_tensor"), - "piezo_ionic": outcar.get("piezo_ionic_tensor"), - "structure": structure, - "updated_on": lu_dt, - "task_updated": task_updated, - self.materials.key: mat_doc[self.materials.key], - } - ) - - if len(final_docs) > 0: - sorted_final_docs = sorted( - final_docs, - key=lambda entry: ( - entry["is_hubbard"], - entry["nkpoints"], - entry["updated_on"], - ), - reverse=True, - ) - return sorted_final_docs[0] - else: - return None diff --git a/emmet-builders/emmet/builders/materials/provenance.py b/emmet-builders/emmet/builders/materials/provenance.py index eb4f08e215..a1ba30c5c1 100644 --- a/emmet-builders/emmet/builders/materials/provenance.py +++ b/emmet-builders/emmet/builders/materials/provenance.py @@ -1,278 +1,145 @@ -import warnings -from collections import defaultdict -from math import ceil -from typing import TYPE_CHECKING +"""Build provenance collection.""" -from maggma.core import Builder, Store -from maggma.utils import grouper from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher -from pymatgen.core.structure import Structure +from emmet.builders.base import BaseBuilderInput from emmet.builders.settings import EmmetBuildSettings -from emmet.core.provenance import 
ProvenanceDoc, SNLDict -from emmet.core.utils import get_sg, jsanitize, utcnow - -if TYPE_CHECKING: - from collections.abc import Iterable - -warnings.warn( - f"The current version of {__name__}.ProvenanceBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, +from emmet.builders.utils import filter_map +from emmet.core.connectors.analysis import parse_cif +from emmet.core.connectors.icsd.client import IcsdClient +from emmet.core.connectors.icsd.enums import IcsdSubset +from emmet.core.provenance import DatabaseSNL, ProvenanceDoc + +SETTINGS = EmmetBuildSettings() +structure_matcher = StructureMatcher( + ltol=SETTINGS.LTOL, + stol=SETTINGS.STOL, + comparator=ElementComparator(), + angle_tol=SETTINGS.ANGLE_TOL, + primitive_cell=True, + scale=True, + attempt_supercell=False, + allow_subset=False, ) -class ProvenanceBuilder(Builder): - def __init__( - self, - materials: Store, - provenance: Store, - source_snls: list[Store], - settings: EmmetBuildSettings | None = None, - query: dict | None = None, - **kwargs, - ): - """ - Creates provenance from source SNLs and materials - - Args: - materials: Store of materials docs to tag with SNLs - provenance: Store to update with provenance data - source_snls: List of locations to grab SNLs - query : query on materials to limit search - """ - self.materials = materials - self.provenance = provenance - self.source_snls = source_snls - self.settings = EmmetBuildSettings.autoload(settings) - self.query = query or {} - self.kwargs = kwargs - - materials.key = "material_id" - provenance.key = "material_id" - for s in source_snls: - s.key = "snl_id" - - super().__init__( - sources=[materials, *source_snls], targets=[provenance], **kwargs - ) - - def ensure_indicies(self): - self.materials.ensure_index("material_id", unique=True) - 
self.materials.ensure_index("formula_pretty") - - self.provenance.ensure_index("material_id", unique=True) - self.provenance.ensure_index("formula_pretty") - - for s in self.source_snls: - s.ensure_index("snl_id") - s.ensure_index("formula_pretty") - - def prechunk(self, number_splits: int) -> Iterable[dict]: # pragma: no cover - self.ensure_indicies() - - # Find all formulas for materials that have been updated since this - # builder was last ran - q = self.query - updated_materials = self.provenance.newer_in( - self.materials, criteria=q, exhaustive=True +def _get_snl_from_cif(cif_str: str, **kwargs) -> DatabaseSNL | None: + """Build a database SNL from a CIF plus its metadata. + + NB: Only takes the first structure from a CIF. + While a CIF can technically contain many structures, + the ICSD usually only distributes CIFs with one structure + per file. + + Parameters + ----------- + cif_str : the CIF to parse + **kwargs to pass to `DatabaseSNL` + """ + try: + structures, cif_parsing_remarks = parse_cif(cif_str) + remarks = kwargs.pop("remarks", None) or cif_parsing_remarks or None + snl = DatabaseSNL.from_structure( + meta_structure=structures[0], + structure=structures[0], + remarks=remarks, + **kwargs, ) - forms_to_update = set( - self.materials.distinct( - "formula_pretty", {"material_id": {"$in": updated_materials}} - ) - ) - - # Find all new SNL formulas since the builder was last run - for source in self.source_snls: - new_snls = self.provenance.newer_in(source) - forms_to_update |= set( - source.distinct("formula_pretty", {source.key: {"$in": new_snls}}) + except Exception: + return None + + if snl and snl.remarks is None: + return snl + return None + + +def update_experimental_icsd_structures(**client_kwargs) -> list[DatabaseSNL]: + """Update the collection of ICSD SNLs. 
+ + Parameters + ----------- + **client_kwargs to pass to `IcsdClient` + + Returns + ----------- + List of DatabaseSNL + """ + data = [] + with IcsdClient(use_document_model=False, **client_kwargs) as client: + for icsd_subset in ( + IcsdSubset.EXPERIMENTAL_METALORGANIC, + IcsdSubset.EXPERIMENTAL_INORGANIC, + ): + data += client.search( + subset=IcsdSubset.EXPERIMENTAL_INORGANIC, + space_group_number=(1, 230), + include_cif=True, + include_metadata=False, ) - # Now reduce to the set of formulas we actually have - forms_avail = set(self.materials.distinct("formula_pretty", self.query)) - forms_to_update = forms_to_update & forms_avail - - mat_ids = set( - self.materials.distinct( - "material_id", {"formula_pretty": {"$in": list(forms_to_update)}} - ) - ) & set(updated_materials) - - N = ceil(len(mat_ids) / number_splits) - - self.logger.info( - f"Found {len(mat_ids)} new/updated systems to distribute to workers " - f"in {N} chunks." - ) - - for chunk in grouper(mat_ids, N): - yield {"query": {"material_id": {"$in": chunk}}} - - def get_items(self) -> tuple[list[dict], list[dict]]: # type: ignore - """ - Gets all materials to assocaite with SNLs - Returns: - generator of materials and SNLs that could match - """ - self.logger.info("Provenance Builder Started") - - self.logger.info("Setting indexes") - self.ensure_indicies() - - # Find all formulas for materials that have been updated since this - # builder was last ran - q = self.query - updated_materials = self.provenance.newer_in( - self.materials, criteria=q, exhaustive=True - ) - forms_to_update = set( - self.materials.distinct( - "formula_pretty", {"material_id": {"$in": updated_materials}} - ) + parsed = [ + _get_snl_from_cif( + doc["cif"], + snl_id=f"icsd-{doc['collection_code']}", + tags=[doc["subset"].value], + source="icsd", ) - - # Find all new SNL formulas since the builder was last run - for source in self.source_snls: - new_snls = self.provenance.newer_in(source) - forms_to_update |= set( - 
source.distinct("formula_pretty", {source.key: {"$in": new_snls}}) - ) - - # Now reduce to the set of formulas we actually have - forms_avail = set(self.materials.distinct("formula_pretty", self.query)) - forms_to_update = forms_to_update & forms_avail - - mat_ids = set( - self.materials.distinct( - "material_id", {"formula_pretty": {"$in": list(forms_to_update)}} - ) - ) & set(updated_materials) - - self.total = len(mat_ids) - - self.logger.info(f"Found {self.total} new/updated systems to process") - - for mat_id in mat_ids: - mat = self.materials.query_one( - properties=[ - "material_id", - "last_updated", - "structure", - "initial_structures", - "formula_pretty", - "deprecated", - ], - criteria={"material_id": mat_id}, - ) - - snls = [] # type: list - for source in self.source_snls: - snls.extend( - source.query(criteria={"formula_pretty": mat["formula_pretty"]}) - ) - - snl_groups = defaultdict(list) - for snl in snls: - struc = Structure.from_dict(snl) - snl_sg = get_sg(struc) - struc.snl = SNLDict(**snl) # type: ignore[attr-defined] - snl_groups[snl_sg].append(struc) - - mat_sg = get_sg(Structure.from_dict(mat["structure"])) - - snl_structs = snl_groups[mat_sg] - - self.logger.debug(f"Found {len(snl_structs)} potential snls for {mat_id}") - yield mat, snl_structs - - def process_item(self, item) -> dict: - """ - Matches SNLS and Materials - Args: - item (tuple): a tuple of materials and snls - Returns: - list(dict): a list of collected snls with material ids - """ - mat, snl_structs = item - formula_pretty = mat["formula_pretty"] - snl_doc = None - self.logger.debug(f"Finding Provenance {formula_pretty}") - - # Match up SNLS with materials - - matched_snls = self.match(snl_structs, mat) - - if len(matched_snls) > 0: - doc = ProvenanceDoc.from_SNLs( - material_id=mat["material_id"], - structure=Structure.from_dict(mat["structure"]), - snls=matched_snls, - deprecated=mat["deprecated"], - ) - else: - doc = ProvenanceDoc( # type: ignore[call-arg] - 
material_id=mat["material_id"], - structure=Structure.from_dict(mat["structure"]), - deprecated=mat["deprecated"], - created_at=utcnow(), - ) - - doc.authors.append(self.settings.DEFAULT_AUTHOR) - doc.history.append(self.settings.DEFAULT_HISTORY) # type: ignore[union-attr] - doc.references.append(self.settings.DEFAULT_REFERENCE) - - snl_doc = jsanitize(doc.dict(exclude_none=False), allow_bson=True) - - return snl_doc - - def match(self, snl_structs, mat): - """ - Finds a material doc that matches with the given snl - Args: - snl_structs ([dict]): the snls struct list - mat (dict): a materials doc - Returns: - generator of materials doc keys - """ - - m_strucs = [Structure.from_dict(mat["structure"])] + [ - Structure.from_dict(init_struc) for init_struc in mat["initial_structures"] - ] - - sm = StructureMatcher( - ltol=self.settings.LTOL, - stol=self.settings.STOL, - angle_tol=self.settings.ANGLE_TOL, - primitive_cell=True, - scale=True, - attempt_supercell=False, - allow_subset=False, - comparator=ElementComparator(), + for doc in data + ] + + return sorted( + [doc for doc in parsed if doc], + key=lambda doc: int(doc.snl_id.split("-", 1)[-1]), + ) + + +def match_against_snls( + input_doc: BaseBuilderInput, + snls: list[DatabaseSNL], +): + """Match a single document against the SNL collection.""" + database_ids = {} + authors = [SETTINGS.DEFAULT_AUTHOR] + history = [SETTINGS.DEFAULT_HISTORY] + references = [SETTINGS.DEFAULT_REFERENCE] + theoretical = True + + for snl in [ + doc + for doc in snls + if doc.chemsys + == ( + "-".join(sorted(input_doc.structure.composition.chemical_system.split("-"))) ) - - snls = [] - - for s in m_strucs: - for snl_struc in snl_structs: - if sm.fit(s, snl_struc): - if snl_struc.snl not in snls: - snls.append(snl_struc.snl) - - self.logger.debug(f"Found {len(snls)} SNLs for {mat['material_id']}") - return snls - - def update_targets(self, items): - """ - Inserts the new SNL docs into the SNL collection - """ - snls = list(filter(None, 
items)) - - if len(snls) > 0: - self.logger.info(f"Found {len(snls)} SNLs to update") - self.provenance.update(snls) - else: - self.logger.info("No items to update") + ]: + if structure_matcher.fit(input_doc.structure, snl.structure): + + if snl.source and snl.source in {"icsd", "pauling"}: + theoretical = False + database_ids[snl.source].append(snl.snl_id) + + if snl.about: + authors.extend(snl.about.authors or []) + history.extend(snl.about.history or []) + # `SNLAbout` uses string for `references`, + # `ProvenanceDoc` uses list of str + if snl.about.references: + references.append(snl.about.references) + + return ProvenanceDoc.from_structure( + meta_structure=input_doc.structure, + material_id=input_doc.material_id, + database_IDs=database_ids, + theoretical=theoretical, + authors=authors, + history=history, + references=references, + ) + + +def build_provenance_docs( + input_documents: list[BaseBuilderInput], snls: list[DatabaseSNL], **kwargs +) -> list[ProvenanceDoc]: + """Build the provenance collection.""" + + return list(filter_map(match_against_snls, input_documents, snls=snls, **kwargs)) diff --git a/emmet-builders/emmet/builders/materials/robocrys.py b/emmet-builders/emmet/builders/materials/robocrys.py index e694598401..08765cffa4 100644 --- a/emmet-builders/emmet/builders/materials/robocrys.py +++ b/emmet-builders/emmet/builders/materials/robocrys.py @@ -1,54 +1,38 @@ -import warnings - -from maggma.builders.map_builder import MapBuilder -from maggma.core import Store -from pymatgen.core.structure import Structure +from robocrys import __version__ as __robocrys_version__ +from robocrys.condense.mineral import MineralMatcher +from emmet.builders.base import BaseBuilderInput +from emmet.builders.utils import filter_map from emmet.core.robocrys import RobocrystallogapherDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.RobocrystallographerBuilder will be deprecated in version 0.87.0. 
" - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class RobocrystallographerBuilder(MapBuilder): - def __init__( - self, - oxidation_states: Store, - robocrys: Store, - query: dict | None = None, - **kwargs, - ): - self.oxidation_states = oxidation_states - self.robocrys = robocrys - self.kwargs = kwargs - self.robocrys.key = "material_id" - self.oxidation_states.key = "material_id" - super().__init__( - source=oxidation_states, - target=robocrys, - query=query, - projection=["material_id", "structure", "deprecated"], - **kwargs, +def build_robocrys_docs( + input_documents: list[BaseBuilderInput], **kwargs +) -> list[RobocrystallogapherDoc]: + """ + Generate robocrystallographer descriptions from input structures. + + Transforms a list of BaseBuilderInput documents containing + Pymatgen structures into corresponding RobocrystallogapherDoc instances by + using robocrys' StructureCondenser and StructureDescriber classes. + + Caller is responsible for creating BaseBuilderInput instances + within their data pipeline context. + + Args: + input_documents: List of BaseBuilderInput documents to process. 
+ + Returns: + list[RobocrystallogapherDoc] + """ + mineral_matcher = MineralMatcher() + return list( + filter_map( + RobocrystallogapherDoc.from_structure, + input_documents, + work_keys=["deprecated", "material_id", "structure", "origins"], + mineral_matcher=mineral_matcher, + robocrys_version=__robocrys_version__, + **kwargs ) - - def unary_function(self, item): - structure = Structure.from_dict(item["structure"]) - mpid = item["material_id"] - deprecated = item["deprecated"] - - doc = RobocrystallogapherDoc.from_structure( - structure=structure, - material_id=mpid, - deprecated=deprecated, - fields=[], - ) - - return jsanitize(doc.model_dump(), allow_bson=True) + ) diff --git a/emmet-builders/emmet/builders/materials/similarity.py b/emmet-builders/emmet/builders/materials/similarity.py index 2fd380cbb8..1b818f7e14 100644 --- a/emmet-builders/emmet/builders/materials/similarity.py +++ b/emmet-builders/emmet/builders/materials/similarity.py @@ -1,160 +1,120 @@ -import warnings - import numpy as np -from maggma.builders import Builder - -__author__ = "Nils E. R. Zimmermann " - -# TODO: -# 1) ADD DOCUMENT MODEL - -warnings.warn( - f"The current version of {__name__}.StructureSimilarityBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, +from emmet.builders.base import BaseBuilderInput +from emmet.core.similarity import ( + CrystalNNSimilarity, + M3GNetSimilarity, + SimilarityDoc, + SimilarityEntry, + SimilarityMethod, ) - -class StructureSimilarityBuilder(Builder): - def __init__(self, site_descriptors, structure_similarity, fp_type="csf", **kwargs): - """ - Calculates similarity metrics between structures on the basis - of site descriptors. 
- - Args: - site_descriptors (Store): storage of site-descriptors data - such as tetrahedral order parameter - or percentage of 8-fold coordination. - structure_similarity (Store): storage of structure similarity - metrics. - fp_type (str): target site fingerprint type to be - used for similarity computation - ("csf" (based on matminer's - CrystalSiteFingerprint class) - or "opsf" (based on matminer's - OPSiteFingerprint class)). - """ - - self.site_descriptors = site_descriptors - self.structure_similarity = structure_similarity - self.fp_type = fp_type - - super().__init__( - sources=[site_descriptors], targets=[structure_similarity], **kwargs +SIM_METHOD_TO_SCORER = { + SimilarityMethod(k): v + for k, v in { + "CrystalNN": CrystalNNSimilarity, + "M3GNet": M3GNetSimilarity, + }.items() +} + + +class SimilarityBuilderInput(BaseBuilderInput): + """Augment base builder input with extra fields.""" + + similarity_method: SimilarityMethod + feature_vector: list[float] + + +# this could probably be parallelized over `similarity_method` +def build_feature_vectors( + input_documents: list[BaseBuilderInput], + similarity_method: SimilarityMethod | str = SimilarityMethod.CRYSTALNN, +) -> list[SimilarityBuilderInput]: + """Generate similarity feature vectors. + + Args: + input_documents : list of BaseBuilderInput to process + similarity_method : SimilarityMethod = SimilarityMethod.CRYSTALNN + The method to use in building similarity docs. + Returns: + list of SimilarityBuilderInput + """ + if isinstance(similarity_method, str): + similarity_method = ( + SimilarityMethod[similarity_method] + if similarity_method in SimilarityMethod.__members__ + else SimilarityMethod(similarity_method) ) - def get_items(self): - """ - Gets all materials that need new site descriptors. - - Returns: - generator of materials to calculate site descriptors. 
- """ - - self.logger.info("Structure Similarity Builder Started") - - self.logger.info("Setting indexes") - - # TODO: re-introduce last-updated filtering. - task_ids = list(self.site_descriptors.distinct(self.site_descriptors.key)) - n_task_ids = len(task_ids) - for i in range(n_task_ids - 1): - d1 = self.site_descriptors.query_one( - properties=[self.site_descriptors.key, "statistics"], - criteria={self.site_descriptors.key: task_ids[i]}, - ) - for j in range(i + 1, n_task_ids): - d2 = self.site_descriptors.query_one( - properties=[self.site_descriptors.key, "statistics"], - criteria={self.site_descriptors.key: task_ids[j]}, - ) - yield list([d1, d2]) - - def process_item(self, item): - """ - Calculates site descriptors for the structures - - Args: - item (list): a list (length 2) with each one document that - carries a task ID in "task_id" and a statistics - vector from OP site-fingerprints in - "statistics". - - Returns: - dict: similarity measures. - """ - self.logger.debug( - "Similarities for {} and {}".format( - item[0][self.site_descriptors.key], item[1][self.site_descriptors.key] - ) + if scorer_cls := SIM_METHOD_TO_SCORER.get(similarity_method): + scorer = scorer_cls() + else: + raise ValueError(f"Unsupported {similarity_method=}") + + return list( + map( + lambda x: SimilarityBuilderInput( + material_id=x.material_id, + structure=x.structure, + similarity_method=similarity_method, + feature_vector=scorer._featurize_structure(x.structure), + ), + input_documents, ) - - sim_doc = {} - sim_doc = self.get_similarities(item[0], item[1]) - sim_doc[self.structure_similarity.key] = tuple( - sorted( - [item[0][self.site_descriptors.key], item[1][self.site_descriptors.key]] - ) + ) + + +def build_similarity_docs( + input_documents: list[SimilarityBuilderInput], + num_closest: int = 100, +) -> list[SimilarityDoc]: + """Generate similarity feature vectors. + + All input docs should use the same similarity method. + A check is performed at the start to ensure this. 
+ + Args: + input_documents : list of SimilarityBuilderInput to process + num_closest : int = 100 + The number of most similar materials to identify + for each material + Returns: + list of SimilarityDoc + """ + + if ( + len(distinct_sim_methods := {doc.similarity_method for doc in input_documents}) + > 1 + ): + raise ValueError( + f"Multiple similarity methods found: {', '.join(distinct_sim_methods)}" ) - return sim_doc - - def update_targets(self, items): - """ - Inserts the new task_types into the task_types collection. - - Args: - items ([[dict]]): a list of list of site-descriptors dictionaries to update. - """ - if len(items) > 0: - self.logger.info("Updating {} structure-similarity docs".format(len(items))) - self.structure_similarity.update(docs=items) - else: - self.logger.info("No items to update") - - def get_similarities(self, d1, d2): - doc = {} - - # Compute similarty metrics. - try: - dout = {} - l = {} - v = {} - for i, li in enumerate( - [d1["statistics"][self.fp_type], d2["statistics"][self.fp_type]] - ): - v[i] = [] - l[i] = [] - # for optype, stats in d.items(): - for opdict in li: - for stattype, val in opdict.items(): - if stattype != "name": - v[i].append(val) - l[i].append("{} {}".format(opdict["name"], stattype)) - if len(l[0]) != len(l[1]): - raise RuntimeError( - "Site-fingerprint statistics dictionaries" - " have different sizes ({}, {})".format(len(l[0]), len(l[1])) - ) - for k in l[0]: - if k not in l[1]: - raise RuntimeError( - 'Label "{}" not found in second site-' - "fingerprint statistics " - "dictionary".format(k) + scorer_cls = SIM_METHOD_TO_SCORER[method := input_documents[0].similarity_method] + material_ids, vectors, structures = np.array( + [doc.material_id, doc.feature_vector, doc.structure] for doc in input_documents + ).T + + similarity_docs = [] + for i, material_id in enumerate(material_ids): + closest_idxs, closest_dist = scorer_cls._get_closest_vectors( + i, vectors, num_closest + ) + similarity_docs.append( + 
SimilarityDoc.from_structure( + meta_structure=structures[i], + material_id=material_id, + feature_vector=vectors[i], + method=method, + sim=[ + SimilarityEntry( + task_id=material_ids[jdx], + nelements=len(structures[jdx].composition.elements), + dissimilarity=100.0 - closest_dist[j], + formula=structures[jdx].formula, ) - v1 = np.array([v[0][k] for k in range(len(l[0]))]) - v2 = np.array([v[1][l[1].index(k)] for k in l[0]]) - dout["cos"] = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) - dout["dist"] = np.linalg.norm(v1 - v2) - doc = dout - - except Exception as e: - self.logger.error( - "Failed calculating structure similarity" "metrics: {}".format(e) + for j, jdx in enumerate(closest_idxs) + ], ) - - return doc + ) + return similarity_docs diff --git a/emmet-builders/emmet/builders/materials/substrates.py b/emmet-builders/emmet/builders/materials/substrates.py deleted file mode 100644 index 5fbb11edf3..0000000000 --- a/emmet-builders/emmet/builders/materials/substrates.py +++ /dev/null @@ -1,192 +0,0 @@ -import warnings -from typing import Iterable - -from maggma.core.builder import Builder -from maggma.core.store import Store -from maggma.utils import grouper -from pymatgen.analysis.elasticity.elastic import ElasticTensor -from pymatgen.core.structure import Structure -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer - -from emmet.core.mpid import AlphaID -from emmet.core.substrates import SubstratesDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.SubstratesBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class SubstratesBuilder(Builder): - def __init__( - self, - materials: Store, - substrates: Store, - elasticity: Store, - query: dict | None = None, - **kwargs, - ): - """ - Calculates matching substrates - - Args: - materials (Store): Store of materials documents - diffraction (Store): Store of substrate matches - elasticity (Store): Store of elastic tensor documents - substrates_file (path): file of substrates to consider - query (dict): dictionary to limit materials to be analyzed - """ - self.materials = materials - self.substrates = substrates - self.elasticity = elasticity - self.query = query - self.kwargs = kwargs - - # Enforce that we key on material_id - self.materials.key = "material_id" - self.substrates.key = "material_id" - self.elasticity.key = "material_id" - - super().__init__( - sources=[materials, elasticity], - targets=[substrates], - **kwargs, - ) - - def prechunk(self, number_splits: int) -> Iterable[dict]: # pragma: no cover - to_process_mat_ids = self._find_to_process() - - return [ - {"material_id": {"$in": list(chunk)}} - for chunk in grouper(to_process_mat_ids, number_splits) - ] - - def get_items(self): - """ - Gets all materials that need new substrates - - Returns: - generator of materials to calculate substrates - """ - - to_process_mat_ids = self._find_to_process() - - self.logger.info( - "Updating all substrate calculations for {} materials".format( - len(to_process_mat_ids) - ) - ) - - for mpid in to_process_mat_ids: - e_tensor = self.elasticity.query_one( - criteria={self.elasticity.key: mpid}, - properties=["elasticity", "last_updated"], - ) - e_tensor = ( - e_tensor.get("elasticity", {}).get("elastic_tensor", None) - if e_tensor - else None - ) - mat = self.materials.query_one( - criteria={self.materials.key: mpid}, - properties=["structure", "deprecated", "material_id", "last_updated"], - ) - - yield { - "structure": 
mat["structure"], - "material_id": mat[self.materials.key], - "elastic_tensor": e_tensor, - "deprecated": mat["deprecated"], - "last_updated": max( - mat.get("last_updated"), e_tensor.get("last_updated") - ), - } - - def process_item(self, item): - """ - Calculates substrate matches for all given substrates - - Args: - item (dict): a dict with a material_id and a structure - - Returns: - dict: a diffraction dict - """ - - mpid = AlphaID(item["material_id"]) - elastic_tensor = item.get("elastic_tensor", None) - elastic_tensor = ( - ElasticTensor.from_voigt(elastic_tensor) if elastic_tensor else None - ) - deprecated = item["deprecated"] - - self.logger.debug("Calculating substrates for {}".format(item["task_id"])) - - # Ensure we're using conventional standard to be consistent with IEEE elastic tensor setting - film = conventional_standard_structure(item) - - substrate_doc = SubstratesDoc.from_structure( - material_id=mpid, - structure=film, - elastic_tensor=elastic_tensor, - deprecated=deprecated, - last_updated=item["last_updated"], - ) - - return jsanitize(substrate_doc.model_dump(), allow_bson=True) - - def update_targets(self, items): - """ - Inserts the new substrate matches into the substrates collection - - Args: - items ([[dict]]): a list of list of thermo dictionaries to update - """ - - items = list(filter(None, items)) - - if len(items) > 0: - self.logger.info("Updating {} substrate matches".format(len(items))) - self.substrates.update(docs=items) - else: - self.logger.info("No items to update") - - def ensure_indicies(self): - """ - Ensures indicies on the substrates, materials, and elastic collections - """ - # Search indicies for materials - self.materials.ensure_index(self.materials.key) - self.materials.ensure_index(self.materials.last_updated_field) - - # Search indicies for elasticity - self.elasticity.ensure_index(self.elasticity.key) - self.elasticity.ensure_index(self.elasticity.last_updated_field) - - # Search indicies for substrates - 
self.substrates.ensure_index(self.substrates.key) - self.substrates.ensure_index(self.substrates.last_updated_field) - - def _find_to_process(self): - self.logger.info("Substrate Builder Started") - - self.logger.info("Setting up indicies") - self.ensure_indicies() - - mat_keys = set(self.materials.distinct("material_id", criteria=self.query)) - updated_mats = self.materials.newer_in(self.substrates) - e_tensor_updated_mats = self.elasticity.newer_in(self.substrates) - - # To ensure all mats are within our scope - return set(e_tensor_updated_mats + updated_mats) & mat_keys - - -def conventional_standard_structure(doc): - """Get a conventional standard structure from doc["structure"].""" - s = Structure.from_dict(doc["structure"]) - spga = SpacegroupAnalyzer(s, symprec=0.1) - return spga.get_conventional_standard_structure() diff --git a/emmet-builders/emmet/builders/materials/summary.py b/emmet-builders/emmet/builders/materials/summary.py index 4a8e753372..188124894b 100644 --- a/emmet-builders/emmet/builders/materials/summary.py +++ b/emmet-builders/emmet/builders/materials/summary.py @@ -1,243 +1,93 @@ -import warnings -from math import ceil - -from maggma.builders import Builder -from maggma.utils import grouper - -from emmet.core.mpid import AlphaID -from emmet.core.summary import HasProps, SummaryDoc -from emmet.core.types.enums import ThermoType -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.SummaryBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, +from pydantic import BaseModel + +from emmet.builders.utils import filter_map +from emmet.core.summary import ( + AbsorptionData, + BandstructureSummary, + ChargeDensityData, + ChemenvData, + DielectricSummary, + DosSummary, + ElasticitySummary, + ElectrodesData, + ElectronicStructureSummary, + EosData, + GBSummary, + MagnetismSummary, + MaterialsSummary, + OxiStatesSummary, + PhononData, + PiezoelectricSummary, + ProvenanceSummary, + SubstratesData, + SummaryDoc, + SurfacesSummary, + ThermoSummary, + XASSummary, ) -class SummaryBuilder(Builder): - def __init__( - self, - materials, - thermo, - xas, - chemenv, - absorption, - grain_boundaries, - electronic_structure, - magnetism, - elasticity, - dielectric, - piezoelectric, - phonon, - insertion_electrodes, - substrates, - surfaces, - oxi_states, - eos, - provenance, - charge_density_index, - summary, - thermo_type=ThermoType.GGA_GGA_U_R2SCAN.value, - chunk_size=100, - query=None, - **kwargs, - ): - self.materials = materials - self.thermo = thermo - self.xas = xas - self.chemenv = chemenv - self.absorption = absorption - self.grain_boundaries = grain_boundaries - self.electronic_structure = electronic_structure - self.magnetism = magnetism - self.elasticity = elasticity - self.dielectric = dielectric - self.piezoelectric = piezoelectric - self.phonon = phonon - self.insertion_electrodes = insertion_electrodes - self.substrates = substrates - self.surfaces = surfaces - self.oxi_states = oxi_states - self.eos = eos - self.provenance = provenance - self.charge_density_index = charge_density_index - - self.thermo_type = thermo_type - - self.summary = summary - self.chunk_size = chunk_size - self.query = query if query else {} - - super().__init__( - sources=[ - materials, - thermo, - xas, - chemenv, - absorption, - grain_boundaries, - electronic_structure, - magnetism, - elasticity, - dielectric, - piezoelectric, - 
phonon, - insertion_electrodes, - surfaces, - oxi_states, - substrates, - eos, - provenance, - charge_density_index, - ], - targets=[summary], - chunk_size=chunk_size, - **kwargs, - ) - - def get_items(self): - """ - Gets all items to process - - Returns: - list of relevant materials and data - """ - - self.logger.info("Summary Builder Started") - - q = dict(self.query) - - mat_ids = self.materials.distinct(field=self.materials.key, criteria=q) - summary_ids = self.summary.distinct(field=self.summary.key, criteria=q) - - summary_set = set(mat_ids) - set(summary_ids) - - self.total = len(summary_set) - - self.logger.debug("Processing {} materials.".format(self.total)) - - for entry in summary_set: - materials_doc = self.materials.query_one({self.materials.key: entry}) - - valid_static_tasks = set( - [ - task_id - for task_id, task_type in materials_doc["task_types"].items() - if task_type == "Static" - ] - ) - set(materials_doc["deprecated_tasks"]) - - all_tasks = list(materials_doc["task_types"].keys()) - - data = { - HasProps.materials.value: materials_doc, - HasProps.thermo.value: self.thermo.query_one( - {self.materials.key: entry, "thermo_type": str(self.thermo_type)} - ), - HasProps.xas.value: list( - self.xas.query({self.xas.key: {"$in": all_tasks}}) - ), - HasProps.grain_boundaries.value: list( - self.grain_boundaries.query({self.grain_boundaries.key: entry}) - ), - HasProps.electronic_structure.value: self.electronic_structure.query_one( - {self.electronic_structure.key: entry} - ), - HasProps.magnetism.value: self.magnetism.query_one( - {self.magnetism.key: entry} - ), - HasProps.elasticity.value: self.elasticity.query_one( - {self.elasticity.key: {"$in": all_tasks}} - ), - HasProps.dielectric.value: self.dielectric.query_one( - {self.dielectric.key: entry} - ), - HasProps.piezoelectric.value: self.piezoelectric.query_one( - {self.piezoelectric.key: entry} - ), - HasProps.phonon.value: self.phonon.query_one( - {self.phonon.key: {"$in": all_tasks}}, - 
[self.phonon.key], - ), - HasProps.insertion_electrodes.value: list( - self.insertion_electrodes.query( - {"material_ids": entry}, - [self.insertion_electrodes.key], - ) - ), - HasProps.surface_properties.value: self.surfaces.query_one( - {self.surfaces.key: {"$in": all_tasks}} - ), - HasProps.substrates.value: list( - self.substrates.query( - {self.substrates.key: {"$in": all_tasks}}, [self.substrates.key] - ) - ), - HasProps.oxi_states.value: self.oxi_states.query_one( - {self.oxi_states.key: entry} - ), - HasProps.eos.value: self.eos.query_one( - {self.eos.key: {"$in": all_tasks}}, [self.eos.key] - ), - HasProps.chemenv.value: self.chemenv.query_one( - {self.chemenv.key: entry} - ), - HasProps.absorption.value: self.absorption.query_one( - {self.absorption.key: entry} - ), - HasProps.provenance.value: self.provenance.query_one( - {self.provenance.key: entry} - ), - HasProps.charge_density.value: self.charge_density_index.query_one( - {"task_id": {"$in": list(valid_static_tasks)}}, ["task_id"] - ), - } - - sub_fields = {} - - for collection, sub_field in sub_fields.items(): - if data[collection] is not None: - data[collection] = ( - data[collection][sub_field] - if (sub_field in data[collection]) - and (data[collection][sub_field] != {}) - else None - ) - - yield data - - def prechunk(self, number_splits: int): # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = dict(self.query) - - keys = self.summary.newer_in(self.materials, criteria=q, exhaustive=True) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.materials.key: {"$in": list(split)}}} - - def process_item(self, item): - material_id = AlphaID(item[HasProps.materials.value]["material_id"]) - doc = SummaryDoc.from_docs(material_id=material_id, **item) - return jsanitize(doc.model_dump(exclude_none=False), allow_bson=True) - - def update_targets(self, items): - """ - Copy each summary doc to the store - - Args: - items 
([dict]): A list of dictionaries of mpid document pairs to update - """ - items = list(filter(None, items)) - - if len(items) > 0: - self.logger.info("Inserting {} summary docs".format(len(items))) - self.summary.update(docs=items) - else: - self.logger.info("No summary entries to update") +class SummaryBuilderInputs(BaseModel): + """ + Input model for building summary documents. + + Bundles the property summary documents and property shim documents + needed to construct a single SummaryDoc. Property summary documents + contribute field values to the resulting SummaryDoc, while property + shim documents are used solely to populate the has_props mapping. + """ + + property_summary_docs: list[ + MaterialsSummary + | ThermoSummary + | XASSummary + | GBSummary + | ElectronicStructureSummary + | BandstructureSummary + | DosSummary + | MagnetismSummary + | ElasticitySummary + | DielectricSummary + | PiezoelectricSummary + | SurfacesSummary + | OxiStatesSummary + | ProvenanceSummary + ] + property_shim_docs: list[ + ChargeDensityData + | EosData + | PhononData + | AbsorptionData + | ElectrodesData + | SubstratesData + | ChemenvData + ] + + +def build_summary_docs( + input_documents: list[SummaryBuilderInputs], **kwargs +) -> list[SummaryDoc]: + """ + Generate summary documents from input property documents. + + Transforms a list of SummaryBuilderInputs into corresponding + SummaryDoc instances by merging property summary documents and + property shim documents for each material. Each SummaryDoc + aggregates fields from all provided property summaries and tracks + which properties are available via the has_props mapping. + + Caller is responsible for creating SummaryBuilderInputs instances + within their data pipeline context. + + Args: + input_documents: List of SummaryBuilderInputs documents to process. 
+ + Returns: + list[SummaryDoc] + """ + return filter_map( + SummaryDoc.from_docs, + input_documents, + work_keys=["property_summary_docs", "property_shim_docs"], + **kwargs + ) diff --git a/emmet-builders/emmet/builders/materials/thermo.py b/emmet-builders/emmet/builders/materials/thermo.py index 599d5e6de8..0412cfcb8a 100644 --- a/emmet-builders/emmet/builders/materials/thermo.py +++ b/emmet-builders/emmet/builders/materials/thermo.py @@ -1,306 +1,121 @@ -from __future__ import annotations - +import logging import warnings -from datetime import datetime -from itertools import chain -from math import ceil -from typing import TYPE_CHECKING -from maggma.core import Builder, Store -from maggma.stores import S3Store -from maggma.utils import grouper -from monty.json import MontyDecoder -from pymatgen.analysis.phase_diagram import PhaseDiagramError +from pydantic import BaseModel, Field +from pymatgen.analysis.phase_diagram import PhaseDiagram, PhaseDiagramError from pymatgen.entries.computed_entries import ComputedStructureEntry from emmet.builders.utils import HiddenPrints from emmet.core.thermo import PhaseDiagramDoc, ThermoDoc -from emmet.core.utils import jsanitize - -if TYPE_CHECKING: - from collections.abc import Iterator - -warnings.warn( - f"The current version of {__name__}.ThermoBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, +from emmet.core.types.enums import ThermoType +from emmet.core.types.pymatgen_types.computed_entries_adapter import ( + ComputedStructureEntryType, ) +from emmet.core.vasp.calc_types.enums import RunType -class ThermoBuilder(Builder): - def __init__( - self, - thermo: Store, - corrected_entries: Store, - phase_diagram: Store | None = None, - query: dict | None = None, - num_phase_diagram_eles: int | None = None, - chunk_size: int = 1000, - **kwargs, - ): - """ - Calculates thermodynamic quantities for materials from phase - diagram constructions - - Args: - thermo (Store): Store of thermodynamic data such as formation - energy and decomposition pathway - corrected_entries (Store): Store of corrected entry data to use in thermo data and phase diagram - construction. This is required and should be built with the CorrectedEntriesBuilder. - phase_diagram (Store): Store of phase diagram data for each unique chemical system - query (dict): dictionary to limit materials to be analyzed - num_phase_diagram_eles (int): Maximum number of elements to use in phase diagram construction - for data within the separate phase_diagram collection - chunk_size (int): Size of chemsys chunks to process at any one time. - """ +class ThermoBuilderInput(BaseModel): + """ + Minimum inputs required to build ThermoDocs and PhaseDiagramDocs + for a chemical system. + """ - self.thermo = thermo - self.query = query if query else {} - self.corrected_entries = corrected_entries - self.phase_diagram = phase_diagram - self.num_phase_diagram_eles = num_phase_diagram_eles - self.chunk_size = chunk_size - self._completed_tasks: set[str] = set() + chemsys: str = Field( + ..., + description="Dash-delimited string of elements in the material.", + ) - if self.thermo.key != "thermo_id": - warnings.warn( - f"Key for the thermo store is incorrect and has been changed from {self.thermo.key} to thermo_id!" 
- ) - self.thermo.key = "thermo_id" - - if self.corrected_entries.key != "chemsys": - warnings.warn( - "Key for the corrected entries store is incorrect and has been changed " - f"from {self.corrected_entries.key} to chemsys!" - ) - self.corrected_entries.key = "chemsys" - - sources = [corrected_entries] - targets = [thermo] - - if self.phase_diagram is not None: - if self.phase_diagram.key != "phase_diagram_id": - warnings.warn( - f"Key for the phase diagram store is incorrect and has been changed from {self.phase_diagram.key} to phase_diagram_id!" # noqa: E501 - ) - self.phase_diagram.key = "phase_diagram_id" - - targets.append(phase_diagram) # type: ignore - - super().__init__( - sources=sources, targets=targets, chunk_size=chunk_size, **kwargs - ) + entries: dict[RunType | ThermoType, list[ComputedStructureEntryType]] = Field( + ..., + description=""" + List of all computed entries for 'chemsys' that are valid for the specified thermo type. + Entries for elemental endpoints of 'chemsys' are required. 
+ """, + ) - def ensure_indexes(self): - """ - Ensures indicies on the tasks and materials collections - """ - # Search index for corrected_entries - self.corrected_entries.ensure_index("chemsys") - self.corrected_entries.ensure_index("last_updated") +class ThermoBuilderOutput(BaseModel): + """Output of build_thermo_docs_and_phase_diagram_docs function""" - # Search index for thermo - self.thermo.ensure_index("material_id") - self.thermo.ensure_index("thermo_id") - self.thermo.ensure_index("thermo_type") - self.thermo.ensure_index("last_updated") + chemsys: str + thermo_docs: dict[RunType | ThermoType, list[ThermoDoc] | None] + phase_diagram_docs: dict[RunType | ThermoType, PhaseDiagramDoc | None] - # Search index for thermo - self.thermo.ensure_index("material_id") - self.thermo.ensure_index("thermo_id") - self.thermo.ensure_index("thermo_type") - self.thermo.ensure_index("last_updated") - # Search index for phase_diagram - if self.phase_diagram: - coll = self.phase_diagram +ThermoPDPair = tuple[list[ThermoDoc] | None, PhaseDiagramDoc | None] - if isinstance(self.phase_diagram, S3Store): - coll = self.phase_diagram.index +logger = logging.getLogger() - coll.ensure_index("chemsys") - coll.ensure_index("phase_diagram_id") - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - to_process_chemsys = self._get_chemsys_to_process() +def build_thermo_docs_and_phase_diagram_docs( + thermo_input: ThermoBuilderInput, +) -> ThermoBuilderOutput: + chemsys = thermo_input.chemsys - N = ceil(len(to_process_chemsys) / number_splits) - - for chemsys_chunk in grouper(to_process_chemsys, N): - yield {"query": {"chemsys": {"$in": list(chemsys_chunk)}}} - - def get_items(self) -> Iterator[list[dict]]: - """ - Gets whole chemical systems of entries to process - """ - - self.logger.info("Thermo Builder Started") - - self.logger.info("Setting indexes") - self.ensure_indexes() - - to_process_chemsys = self._get_chemsys_to_process() - - self.logger.info( - f"Found 
{len(to_process_chemsys)} chemical systems with new/updated materials to process" + thermo_docs = dict() + phase_diagram_docs = dict() + for thermo_type, entry_list in thermo_input.entries.items(): + logger.debug( + f"Processing {len(entry_list)} entries for: {chemsys} and thermo type: {thermo_type}" ) - self.total = len(to_process_chemsys) - # Yield the chemical systems in order of increasing size - # Will build them in a similar manner to fast Pourbaix - for chemsys in sorted( - to_process_chemsys, key=lambda x: len(x.split("-")), reverse=False - ): - corrected_entries = self.corrected_entries.query_one({"chemsys": chemsys}) - yield corrected_entries - - def process_item(self, item): - if not item: - return None - - pd_thermo_doc_pair_list = [] - - for thermo_type, entry_list in item["entries"].items(): - if entry_list: - entries = [ - ComputedStructureEntry.from_dict(entry) for entry in entry_list - ] - chemsys = item["chemsys"] - elements = chemsys.split("-") - - self.logger.debug( - f"Processing {len(entries)} entries for {chemsys} and thermo type {thermo_type}" + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with HiddenPrints(): + _thermo_docs, _phase_diagram_doc = _produce_pair( + entry_list, thermo_type ) + thermo_docs[thermo_type] = _thermo_docs + phase_diagram_docs[thermo_type] = _phase_diagram_doc + + return ThermoBuilderOutput( + chemsys=chemsys, + thermo_docs=thermo_docs, + phase_diagram_docs=phase_diagram_docs, + ) + + +def _produce_pair( + computed_structure_entries: list[ComputedStructureEntry], + thermo_type: RunType | ThermoType, +) -> ThermoPDPair: + phase_diagram_doc = None + try: + phase_diagram: PhaseDiagram = ThermoDoc.construct_phase_diagram( + computed_structure_entries + ) + thermo_docs: list[ThermoDoc] = ThermoDoc.from_entries( + computed_structure_entries, + thermo_type, + phase_diagram, + use_max_chemsys=True, + deprecated=False, + ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with 
HiddenPrints(): - pd_thermo_doc_pair_list.append( - self._produce_pair(entries, thermo_type, elements) - ) - - return pd_thermo_doc_pair_list - - def _produce_pair(self, pd_entries, thermo_type, elements): - # Produce thermo and phase diagram pair - - try: - # Obtain phase diagram - pd = ThermoDoc.construct_phase_diagram(pd_entries) - - # Iterate through entry material IDs and construct list of thermo docs to update - docs = ThermoDoc.from_entries( - pd_entries, thermo_type, pd, use_max_chemsys=True, deprecated=False - ) - - pd_docs = [None] - - if self.phase_diagram: - if ( - self.num_phase_diagram_eles is None - or len(elements) <= self.num_phase_diagram_eles - ): - chemsys = "-".join(sorted(set([e.symbol for e in pd.elements]))) - pd_id = "{}_{}".format(chemsys, str(thermo_type)) - pd_doc = PhaseDiagramDoc( - phase_diagram_id=pd_id, - chemsys=chemsys, - phase_diagram=pd, - thermo_type=thermo_type, - ) - - pd_data = jsanitize(pd_doc.model_dump(), allow_bson=True) - - pd_docs = [pd_data] - - docs_pd_pair = ( - jsanitize([d.model_dump() for d in docs], allow_bson=True), - pd_docs, + if phase_diagram: + chemsys = "-".join( + sorted(set([el.symbol for el in phase_diagram.elements])) ) - - return docs_pd_pair - - except PhaseDiagramError as p: - elsyms = [] - for e in pd_entries: - elsyms.extend([el.symbol for el in e.composition.elements]) - - self.logger.error( - f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}" + phase_diagram_id = f"{chemsys}_{thermo_type.value}" + phase_diagram_doc = PhaseDiagramDoc( + phase_diagram_id=phase_diagram_id, + chemsys=chemsys, + phase_diagram=phase_diagram, + thermo_type=thermo_type, ) - return (None, None) - - def update_targets(self, items): - """ - Inserts the thermo and phase diagram docs into the thermo collection - Args: - items ([[tuple(list[dict],list[dict])]]): a list of a list of thermo and phase diagram dict pairs to update - """ - - thermo_docs = [pair[0] for pair_list in items for pair in pair_list] - 
phase_diagram_docs = [pair[1] for pair_list in items for pair in pair_list] - - # flatten out lists - thermo_docs = list(filter(None, chain.from_iterable(thermo_docs))) - phase_diagram_docs = list(filter(None, chain.from_iterable(phase_diagram_docs))) - - # Check if already updated this run - thermo_docs = [ - i for i in thermo_docs if i["thermo_id"] not in self._completed_tasks - ] - - self._completed_tasks |= {i["thermo_id"] for i in thermo_docs} - for item in thermo_docs: - if isinstance(item["last_updated"], dict): - item["last_updated"] = MontyDecoder().process_decoded( - item["last_updated"] - ) - - if self.phase_diagram: - self.phase_diagram.update(phase_diagram_docs) - - if len(thermo_docs) > 0: - self.logger.info(f"Updating {len(thermo_docs)} thermo documents") - self.thermo.update(docs=thermo_docs, key=["thermo_id"]) - else: - self.logger.info("No thermo items to update") + if not thermo_docs: + return None, phase_diagram_doc - def _get_chemsys_to_process(self): - # Use last-updated to find new chemsys - corrected_entries_chemsys_dates = { - d[self.corrected_entries.key]: d[self.corrected_entries.last_updated_field] - for d in self.corrected_entries.query( - self.query, - properties=[ - self.corrected_entries.key, - self.corrected_entries.last_updated_field, - ], - ) - } + return thermo_docs, phase_diagram_doc - thermo_chemsys_dates = {} - for d in self.thermo.query( - {"deprecated": False}, - properties=[self.corrected_entries.key, self.thermo.last_updated_field], - ): - entry = thermo_chemsys_dates.get(d[self.corrected_entries.key], None) - if entry is None or d[self.thermo.last_updated_field] < entry: - thermo_chemsys_dates[d[self.corrected_entries.key]] = d[ - self.thermo.last_updated_field - ] + except PhaseDiagramError as p: + elsyms = [] + for entry in computed_structure_entries: + elsyms.extend([el.symbol for el in entry.composition.elements]) - to_process_chemsys = [ - chemsys - for chemsys in corrected_entries_chemsys_dates - if (chemsys not in 
thermo_chemsys_dates) - or ( - thermo_chemsys_dates[chemsys] - < datetime.fromisoformat(corrected_entries_chemsys_dates[chemsys]) - ) - ] + logger.error( + f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}" + ) - return to_process_chemsys + return None, None diff --git a/emmet-builders/emmet/builders/matscholar/missing_compositions.py b/emmet-builders/emmet/builders/matscholar/missing_compositions.py deleted file mode 100644 index ce833cb2a8..0000000000 --- a/emmet-builders/emmet/builders/matscholar/missing_compositions.py +++ /dev/null @@ -1,231 +0,0 @@ -import itertools -import warnings -from itertools import combinations -from math import ceil -from typing import TYPE_CHECKING - -from maggma.core import Builder -from maggma.stores import MongoStore, MongoURIStore, S3Store -from maggma.utils import grouper -from pymatgen.core import Composition, Element - -if TYPE_CHECKING: - from collections.abc import Iterator - -warnings.warn( - f"The current version of {__name__}.MissingCompositionsBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class MissingCompositionsBuilder(Builder): - """ - Builder that finds compositions not found in the - Materials Project for each chemical system. - Based on the Text Mining project in MPContribs. 
- """ - - def __init__( - self, - phase_diagram: S3Store, - mpcontribs: MongoURIStore, - missing_compositions: MongoStore, - query: dict | None = None, - **kwargs, - ): - """ - Arguments: - phase_diagram: source store for chemsys data - matsholar_store: source store for matscholar data - missing_compositions: Target store to save the missing compositions - query: dictionary to query the phase diagram store - **kwargs: Additional keyword arguments - """ - self.phase_diagram = phase_diagram - self.mpcontribs = mpcontribs - self.missing_compositions = missing_compositions - self.query = query - self.kwargs = kwargs - # TODO: make sure the two lines below are needed? - self.phase_diagram.key = "phase_diagram_id" - self.missing_compositions.key = "chemical_system" - - super().__init__( - sources=[phase_diagram, mpcontribs], - targets=[missing_compositions], - **kwargs, - ) - - def prechunk(self, number_splits: int) -> Iterator[dict]: # pragma: no cover - """ - Prechunk method to perform chunking by the key field - """ - q = self.query or {} # type: ignore - - keys = self.missing_compositions.newer_in( - self.phase_diagram, criteria=q, exhaustive=True - ) - - N = ceil(len(keys) / number_splits) - for split in grouper(keys, N): - yield {"query": {self.phase_diagram.key: {"$in": list(split)}}} - - def get_items(self) -> Iterator[dict]: - """ - Returns all chemical systems (combinations of elements) - to process. 
- Enumarates all chemical systems and queries the - phase diagram for each system, in the case where - the chemical system is not found in the phase diagram, - it returns a dictionary with the chemical system - and an empty list for the missing compositions - """ - self.logger.info("Missing Composition Builder Started") - self.logger.info("Setting up chemical systems to process") - elements = set() - # get all elements - elements = set([e.symbol for e in Element]) - - # Generate all unique combinations of elements to form chemical systems - chemical_systems = [] - for r in range(2, 5): - for combination in combinations(elements, r): - system = "-".join(sorted([str(element) for element in combination])) - chemical_systems.append(system) - q = self.query or {} - projection = { - "chemsys": 1, - "phase_diagram.all_entries.composition": 1, - } - for sys in chemical_systems: - q.update({"chemsys": sys}) - self.logger.info(f"Querying phase diagram for {sys}") - try: - items = self.phase_diagram.query(criteria=q, properties=projection) - # combine all entries from all phase diagrams - all_entries = [] - for item in items: - all_entries = [ - i["composition"] for i in item["phase_diagram"]["all_entries"] - ] - - # Find missing compositions - matscholar_entries = self._get_entries_in_chemsys(sys) - doc = { - "chemsys": sys, - "all_compositions": all_entries, - "matscholar_entries": matscholar_entries, - } - yield doc - except Exception as ex: - self.logger.error(f"Erro looking for phase diagram for {sys}: {ex}") - continue - - def process_item(self, item: dict) -> dict: - """ - Processes a chemical system and finds missing c - ompositions for that system. - Note that it returns a missing_compositions dict - regardless of whether there is a missing composition, - in which case, it contains an empty dictionary for - the missing_composition_entries field. 
- """ - compositions = set() - chemsys = item["chemsys"] - matscholar_entries = item["matscholar_entries"] - self.logger.info( - "Querying entries in MPContribs matscholar" - f"project for the chemical system {chemsys}" - ) - missing_compositions = { - "chemical_system": chemsys, - "missing_composition_entries": {}, - } - - if len(item["all_compositions"]) > 0: - # Get the compositions from retrieved entries, - # and use its reduced_formula - for entry in item["all_compositions"]: - composition = Composition.from_dict(entry) - # Note the reduced formula is a string - # instead of a Composition object - compositions.add(composition.reduced_formula) - - if len(matscholar_entries) == 0: - self.logger.info( - "No entries found in MPContribs" "for the chemical system" - ) - - else: - self.logger.info( - f"Found {len(matscholar_entries)}" - "entries in MPContribs for the chemical system" - ) - - for entry in matscholar_entries: - # Comparing reduced formulae from MPContribs - # and Phase Diagram - if ( - Composition(entry["formula"]).reduced_formula - not in compositions - ): - # this formula doesn't exist in the dictionary, - # make an entry in the missing_compositions dict - if ( - entry["formula"] - not in missing_compositions[ - "missing_composition_entries" - ].keys() - ): - missing_compositions["missing_composition_entries"].update( - { - entry["formula"]: [ - {"link": entry["link"], "doi": entry["doi"]} - ] - } - ) - # formula already exists in the dictionary, append the new entry - else: - missing_compositions["missing_composition_entries"][ - entry["formula"] - ].append({"link": entry["link"], "doi": entry["doi"]}) - - return missing_compositions - - def update_targets(self, items): - """ - Updates the target store with the missing compositions - """ - docs = list(filter(None, items)) - - if len(docs) > 0: - self.logger.info(f"Found {len(docs)} chemical-system docs to update") - self.missing_compositions.update(items) - else: - self.logger.info("No items to 
update") - - def _get_entries_in_chemsys(self, chemsys) -> list: - """Queries the MPContribs Store for entries in a chemical system.""" - # get sub-systems - chemsys_subsystems = [] - elements = chemsys.split("-") - # get all possible combinations - for i in range(2, len(elements) + 1): - chemsys_subsystems += [ - "-".join(sorted(c)) for c in itertools.combinations(elements, i) - ] - - results = [] - for subsystem in chemsys_subsystems: - try: - query = {"project": "matscholar", "data.chemsys": subsystem} - fields = ["formula", "data"] - entries = self.mpcontribs.query(criteria=query, properties=fields) - for entry in entries: - results.append({"formula": entry["formula"], **entry["data"]}) - except Exception as ex: - self.logger.error(f"Error querying MPContribs for {subsystem}: {ex}") - return results diff --git a/emmet-builders/emmet/builders/mobility/__init__.py b/emmet-builders/emmet/builders/mobility/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/emmet-builders/emmet/builders/mobility/migration_graph.py b/emmet-builders/emmet/builders/mobility/migration_graph.py deleted file mode 100644 index 6528b62d1b..0000000000 --- a/emmet-builders/emmet/builders/mobility/migration_graph.py +++ /dev/null @@ -1,112 +0,0 @@ -import warnings - -from maggma.builders.map_builder import MapBuilder -from maggma.stores import MongoStore -from pymatgen.analysis.diffusion.neb.full_path_mapper import MigrationGraph -from pymatgen.apps.battery.insertion_battery import InsertionElectrode - -from emmet.builders.utils import get_hop_cutoff -from emmet.core.mobility.migrationgraph import MigrationGraphDoc -from emmet.core.utils import jsanitize - -warnings.warn( - f"The current version of {__name__}.MigrationGraphBuilder will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. 
A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class MigrationGraphBuilder(MapBuilder): - def __init__( - self, - insertion_electrode: MongoStore, - migration_graph: MongoStore, - algorithm: str = "hops_based", - min_hop_distance: float = 1, - max_hop_distance: float = 7, - populate_sc_fields: bool = True, - min_length_sc: float = 8, - minmax_num_atoms: tuple[int, int] = (80, 120), - ltol: float = 0.2, - stol: float = 0.3, - angle_tol: float = 5, - **kwargs, - ): - self.insertion_electrode = insertion_electrode - self.migration_graph = migration_graph - self.algorithm = algorithm - self.min_hop_distance = min_hop_distance - self.max_hop_distance = max_hop_distance - self.populate_sc_fields = populate_sc_fields - self.min_length_sc = min_length_sc - self.minmax_num_atoms = minmax_num_atoms - self.ltol = ltol - self.stol = stol - self.angle_tol = angle_tol - super().__init__(source=insertion_electrode, target=migration_graph, **kwargs) - self.connect() - - def unary_function(self, item): - warnings = [] - - # get entries and info from insertion electrode - ie = InsertionElectrode.from_dict(item["electrode_object"]) - entries = ie.get_all_entries() - wi_entry = ie.working_ion_entry - - # get migration graph structure - structs = MigrationGraph.get_structure_from_entries(entries, wi_entry) - if len(structs) == 0: - warnings.append("cannot generate migration graph from entries") - d = None - else: - if len(structs) > 1: - warnings.append( - f"migration graph ambiguous: {len(structs)} possible options" - ) - # get hop cutoff distance - d = get_hop_cutoff( - migration_graph_struct=structs[0], - mobile_specie=wi_entry.composition.chemical_system, - algorithm=self.algorithm, - min_hop_distance=self.min_hop_distance, - max_hop_distance=self.max_hop_distance, - ) - - # get migration graph doc - try: - mg_doc = MigrationGraphDoc.from_entries_and_distance( - battery_id=item["battery_id"], - grouped_entries=entries, 
- working_ion_entry=wi_entry, - hop_cutoff=d, - populate_sc_fields=self.populate_sc_fields, - min_length_sc=self.min_length_sc, - minmax_num_atoms=self.minmax_num_atoms, - ltol=self.ltol, - stol=self.stol, - angle_tol=self.angle_tol, - warnings=warnings, - ) - except Exception as e: - mg_doc = MigrationGraphDoc( - battery_id=item["battery_id"], - entries_for_generation=entries, - working_ion_entry=wi_entry, - hop_cutoff=d, - migration_graph=None, - populate_sc_fields=self.populate_sc_fields, - min_length_sc=self.min_length_sc, - minmax_num_atoms=self.minmax_num_atoms, - ltol=self.ltol, - stol=self.stol, - angle_tol=self.angle_tol, - warnings=warnings, - deprecated=True, - ) - self.logger.error(f"error getting MigrationGraphDoc: {e}") - return jsanitize(mg_doc) - - return jsanitize(mg_doc.model_dump()) diff --git a/emmet-builders/emmet/builders/utils.py b/emmet-builders/emmet/builders/utils.py index 59b82f9271..51e305b853 100644 --- a/emmet-builders/emmet/builders/utils.py +++ b/emmet-builders/emmet/builders/utils.py @@ -6,6 +6,15 @@ from gzip import GzipFile from io import BytesIO from itertools import chain, combinations +from typing import ( + TYPE_CHECKING, + Callable, + Iterable, + Iterator, + Mapping, + ParamSpec, + TypeVar, +) import orjson from botocore.exceptions import ClientError @@ -14,11 +23,8 @@ from pymatgen.core import Structure from pymatgen.io.vasp.inputs import PotcarSingle -from emmet.core.types.typing import FSPathType - from emmet.builders.settings import EmmetBuildSettings - -from typing import TYPE_CHECKING +from emmet.core.types.typing import FSPathType if TYPE_CHECKING: from typing import Any, Literal @@ -300,3 +306,104 @@ def get_potcar_stats( stats[calc_type].update({potcar_symbol: summary_stats}) return stats + + +# ----------------------------------------------------------------------------- +# Generics +# ----------------------------------------------------------------------------- + +T = TypeVar("T") +S = TypeVar("S") +P = 
ParamSpec("P") +V = TypeVar("V") + + +def try_call( + fn: Callable[P, T], + /, + *args: P.args, + _default: S = None, + _safe: bool = True, + **kwargs: P.kwargs, +) -> T | S | None: + """Attempt to call a function, returning a _default value if an exception is raised. + + Args: + fn: The function to call. + *args: Positional arguments to forward to ``fn``. + _default: The value to return if ``fn`` raises an exception. + Defaults to ``None``. + _safe: Override behavior of ``try_call`` — propagate exceptions when + ``fn`` raises. Useful for debugging. + **kwargs: Keyword arguments to forward to ``fn``. + + Returns: + The return value of ``fn(*args, **kwargs)`` if successful, + otherwise ``_default``. + """ + if not _safe: + return fn(*args, **kwargs) + try: + return fn(*args, **kwargs) + except Exception: + return _default + + +def filter_map( + fn: Callable[..., T], + work: Iterable[V], + /, + *args: Any, + work_keys: list[str] | None = None, + **kwargs: Any, +) -> Iterator[T]: + """Apply a function to each item in an iterable, yielding non-None results. + + Lazily maps ``fn`` over ``work``, passing each item as the first argument + along with any additional positional and keyword arguments. Results that + are ``None`` are excluded. + + When ``work_keys`` is provided, each item in ``work`` is not passed as a + positional argument. Instead, the specified keys are extracted from each + item (via attribute access or dict lookup) and forwarded to ``fn`` as + keyword arguments, merged with any extra ``**kwargs``. + + Args: + fn: The function to apply to each item in ``work``. + work: The iterable of items to process. + *args: Additional positional arguments to forward to ``fn``. + work_keys: If provided, a list of keys/attributes to extract from + each item in ``work`` and pass as keyword arguments to ``fn``. + **kwargs: Additional keyword arguments to forward to ``fn``. + + Yields: + Non-``None`` results from applying ``fn`` to each item in ``work``. 
+ """ + + def _extract_kwargs(item: Any, keys: list[str]) -> dict[str, Any]: + return { + key: item[key] if isinstance(item, Mapping) else getattr(item, key) + for key in keys + } + + if work_keys is not None: + yield from filter( + lambda y: y is not None, + map( + lambda x: try_call( + fn, + *args, + try_call(_extract_kwargs, x, work_keys, **kwargs), + **kwargs, + ), + work, + ), + ) + else: + yield from filter( + lambda y: y is not None, + map( + lambda x: try_call(fn, x, *args, **kwargs), + work, + ), + ) diff --git a/emmet-builders/emmet/builders/vasp/__init__.py b/emmet-builders/emmet/builders/vasp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/emmet-builders/emmet/builders/vasp/materials.py b/emmet-builders/emmet/builders/vasp/materials.py index 3f6a801a4c..d75a729820 100644 --- a/emmet-builders/emmet/builders/vasp/materials.py +++ b/emmet-builders/emmet/builders/vasp/materials.py @@ -1,368 +1,112 @@ -from __future__ import annotations - -import warnings -from datetime import datetime -from itertools import chain -from math import ceil -from typing import TYPE_CHECKING - -from maggma.builders import Builder -from maggma.stores import Store -from maggma.utils import grouper +from itertools import groupby +from typing import Iterator from emmet.builders.settings import EmmetBuildSettings -from emmet.core.tasks import TaskDoc -from emmet.core.utils import group_structures, jsanitize, undeform_structure +from emmet.core.tasks import CoreTaskDoc +from emmet.core.utils import group_structures, undeform_structure from emmet.core.vasp.calc_types import TaskType from emmet.core.vasp.material import MaterialsDoc +from pydantic import Field -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - -__author__ = "Shyam Dwaraknath " - -SETTINGS = EmmetBuildSettings() -warnings.warn( - f"The current version of {__name__}.MaterialsBuilder will be deprecated in version 0.87.0. 
" - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class MaterialsBuilder(Builder): +class ValidationTaskDoc(CoreTaskDoc): """ - The Materials Builder matches VASP task documents by structure similarity into - materials document. The purpose of this builder is group calculations and determine - the best structure. All other properties are derived from other builders. - - The process is as follows: - - 1.) Find all documents with the same formula - 2.) Select only task documents for the task_types we can select properties from - 3.) Aggregate task documents based on structure similarity - 4.) Create a MaterialDoc from the group of task documents - 5.) Validate material document - + Wrapper for TaskDoc to ensure compatiblity with validation checks + in MaterialsDoc.from_tasks(...) if validation builder is skipped """ - def __init__( - self, - tasks: Store, - materials: Store, - task_validation: Store | None = None, - query: dict | None = None, - settings: EmmetBuildSettings | None = None, - **kwargs, - ): - """ - Args: - tasks: Store of task documents - materials: Store of materials documents to generate - task_validation: Store for storing task validation results - query: dictionary to limit tasks to be analyzed - settings: EmmetSettings to use in the build process - """ - - self.tasks = tasks - self.materials = materials - self.task_validation = task_validation - self.query = query if query else {} - self.settings = EmmetBuildSettings.autoload(settings) - self.kwargs = kwargs - - sources = [tasks] - if self.task_validation: - sources.append(self.task_validation) - super().__init__(sources=sources, targets=[materials], **kwargs) - - def ensure_indexes(self): - """ - Ensures indices on the tasks and materials collections - """ - - # Basic search index for tasks - self.tasks.ensure_index("task_id") - 
self.tasks.ensure_index("last_updated") - self.tasks.ensure_index("state") - self.tasks.ensure_index("formula_pretty") - - # Search index for materials - self.materials.ensure_index("material_id") - self.materials.ensure_index("last_updated") - self.materials.ensure_index("task_ids") - - if self.task_validation: - self.task_validation.ensure_index("task_id") - self.task_validation.ensure_index("valid") - - def prechunk(self, number_splits: int) -> Iterable[dict]: # pragma: no cover - """Prechunk the materials builder for distributed computation""" - temp_query = dict(self.query) - temp_query["state"] = "successful" - if len(self.settings.BUILD_TAGS) > 0 and len(self.settings.EXCLUDED_TAGS) > 0: - temp_query["$and"] = [ - {"tags": {"$in": self.settings.BUILD_TAGS}}, - {"tags": {"$nin": self.settings.EXCLUDED_TAGS}}, - ] - elif len(self.settings.BUILD_TAGS) > 0: - temp_query["tags"] = {"$in": self.settings.BUILD_TAGS} - - self.logger.info("Finding tasks to process") - all_tasks = list( - self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]) - ) - - processed_tasks = set(self.materials.distinct("task_ids")) - to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks - to_process_forms = { - d["formula_pretty"] - for d in all_tasks - if d[self.tasks.key] in to_process_tasks - } - - N = ceil(len(to_process_forms) / number_splits) - - for formula_chunk in grouper(to_process_forms, N): - yield {"query": {"formula_pretty": {"$in": list(formula_chunk)}}} - - def get_items(self) -> Iterator[list[dict]]: - """ - Gets all items to process into materials documents. 
- This does no datetime checking; relying on whether - task_ids are included in the Materials Collection - - Returns: - generator or list relevant tasks and materials to process into materials - documents - """ - - task_types = [t.value for t in self.settings.VASP_ALLOWED_VASP_TYPES] - self.logger.info("Materials builder started") - self.logger.info(f"Allowed task types: {task_types}") - - self.logger.info("Setting indexes") - self.ensure_indexes() - - # Save timestamp to mark buildtime for material documents - self.timestamp = datetime.utcnow() - - # Get all processed tasks: - temp_query = dict(self.query) - temp_query["state"] = "successful" - if len(self.settings.BUILD_TAGS) > 0 and len(self.settings.EXCLUDED_TAGS) > 0: - temp_query["$and"] = [ - {"tags": {"$in": self.settings.BUILD_TAGS}}, - {"tags": {"$nin": self.settings.EXCLUDED_TAGS}}, - ] - elif len(self.settings.BUILD_TAGS) > 0: - temp_query["tags"] = {"$in": self.settings.BUILD_TAGS} - - self.logger.info("Finding tasks to process") - all_tasks = list( - self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]) - ) - - processed_tasks = set(self.materials.distinct("task_ids")) - to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks - to_process_forms = { - d["formula_pretty"] - for d in all_tasks - if d[self.tasks.key] in to_process_tasks - } - - self.logger.info(f"Found {len(to_process_tasks)} unprocessed tasks") - self.logger.info(f"Found {len(to_process_forms)} unprocessed formulas") - - # Set total for builder bars to have a total - self.total = len(to_process_forms) + is_valid: bool = Field(True) - if self.task_validation: - invalid_ids = { - doc[self.tasks.key] - for doc in self.task_validation.query( - {"valid": False}, [self.task_validation.key] - ) - } - else: - invalid_ids = set() - - projected_fields = [ - # "last_updated", - "completed_at", - "task_id", - "formula_pretty", - "output.energy_per_atom", - "output.structure", - "input.parameters", - # needed for 
run_type and task_type - "calcs_reversed.input.parameters", - "calcs_reversed.input.incar", - "calcs_reversed.run_type", - "orig_inputs", - "input.structure", - # needed for entry from task_doc - "output.energy", - "calcs_reversed.output.energy", - "input.is_hubbard", - "input.hubbards", - "calcs_reversed.input.potcar_spec", - "calcs_reversed.output.structure", - # needed for transform deformation structure back for grouping - "transformations", - # misc info for materials doc - "tags", - ] - - for formula in to_process_forms: - tasks_query = dict(temp_query) - tasks_query["formula_pretty"] = formula - tasks = list( - self.tasks.query(criteria=tasks_query, properties=projected_fields) - ) - for t in tasks: - t["is_valid"] = t[self.tasks.key] not in invalid_ids - - yield tasks - - def process_item(self, items: list[dict]) -> list[dict]: - """ - Process the tasks into a list of materials - - Args: - tasks [dict]: a list of task docs - - Returns: - ([dict],list): a list of new materials docs and a list of task_ids that - were processed - """ - tasks = [ - TaskDoc(**task) for task in items - ] # [TaskDoc(**task) for task in items] - formula = tasks[0].formula_pretty - task_ids = [task.task_id for task in tasks] - - # not all tasks contains transformation information - task_transformations = [task.get("transformations", None) for task in items] - - self.logger.debug(f"Processing {formula}: {task_ids}") +def build_material_docs( + input_documents: list[ValidationTaskDoc], + settings: EmmetBuildSettings = EmmetBuildSettings(), +) -> list[MaterialsDoc]: + """ + Aggregate ValidationTaskDocs into MaterialsDocs by chemical formula. + Caller is responsible for creating ValidationTaskDoc instances within + their data pipeline context. + + Groups input documents by formula_pretty, performs structure matching + on each formula group, and constructs a MaterialsDoc for each group of + task documents with matching structures within each formula group. 
+ + Args: + input_documents: List of ValidationTaskDoc objects to process. Must contain + ALL documents for each unique formula_pretty value to avoid incorrect + material splitting. Documents for the same formula should not be split + across multiple function calls. + settings: Builder configuration settings, defaults defined in EmmetBuildSettings. + Relevant settings: VASP_STRUCTURE_QUALITY_SCORES, VASP_USE_STATICS, + VASP_ALLOWED_VASP_TYPES, LTOL, STOL, ANGLE_TOL, and SYMPREC. + + Returns: + list[MaterialsDoc] + """ - grouped_tasks = self.filter_and_group_tasks(tasks, task_transformations) - materials = [] + input_documents.sort(key=lambda x: x.formula_pretty) + materials = [] + for _, group in groupby(input_documents, key=lambda x: x.formula_pretty): + # TODO: logging - task_ids = [task.task_id for task in group] + group = list(group) + task_transformations = [task.transformations for task in group] + grouped_tasks = filter_and_group_tasks(group, task_transformations, settings) for group in grouped_tasks: - # commercial_license == True means that the default CC-BY license is applied - # commercial_license == False means that a CC-BY-NC license is applied - commercial_license = True - for task_doc in group: - if task_doc.tags and set(task_doc.tags).intersection( - set(self.settings.NON_COMMERCIAL_TAGS) - ): - commercial_license = False - break try: - materials.append( - MaterialsDoc.from_tasks( - group, - structure_quality_scores=self.settings.VASP_STRUCTURE_QUALITY_SCORES, - use_statics=self.settings.VASP_USE_STATICS, - commercial_license=commercial_license, - ) + doc = MaterialsDoc.from_tasks( + group, + structure_quality_scores=settings.VASP_STRUCTURE_QUALITY_SCORES, + use_statics=settings.VASP_USE_STATICS, ) + materials.append(doc) except Exception as e: - failed_ids = list({t_.task_id for t_ in group}) - doc = MaterialsDoc.construct_deprecated_material( - group, commercial_license - ) + # TODO: logging - failed_ids = list({t_.task_id for t_ in group}) + doc 
= MaterialsDoc.construct_deprecated_material(group) doc.warnings.append(str(e)) materials.append(doc) - self.logger.warn( - f"Failed making material for {failed_ids}." - f" Inserted as deprecated Material: {doc.material_id}" - ) - - self.logger.debug(f"Produced {len(materials)} materials for {formula}") - return jsanitize([mat.model_dump() for mat in materials], allow_bson=True) + return materials - def update_targets(self, items: list[list[dict]]): - """ - Inserts the new task_types into the task_types collection - Args: - items ([([dict],[int])]): A list of tuples of materials to update and the - corresponding processed task_ids - """ +def filter_and_group_tasks( + tasks: list[ValidationTaskDoc], + task_transformations: list[dict | None], + settings: EmmetBuildSettings, +) -> Iterator[list[ValidationTaskDoc]]: + """Groups tasks by structure matching""" - docs = list(chain.from_iterable(items)) # type: ignore - - for item in docs: - item.update({"_bt": self.timestamp}) - - material_ids = list({item["material_id"] for item in docs}) - - if len(items) > 0: - self.logger.info(f"Updating {len(docs)} materials") - self.materials.remove_docs({self.materials.key: {"$in": material_ids}}) - self.materials.update(docs=docs, key=["material_id"]) - else: - self.logger.info("No items to update") - - def filter_and_group_tasks( - self, tasks: list[TaskDoc], task_transformations: list[dict | None] - ) -> Iterator[list[TaskDoc]]: - """ - Groups tasks by structure matching - """ - - filtered_tasks = [] - filtered_transformations = [] - for task, transformations in zip(tasks, task_transformations): - if any( - allowed_type == task.task_type - for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES - ): - filtered_tasks.append(task) - filtered_transformations.append(transformations) - - structures = [] - for idx, (task, transformations) in enumerate( - zip(filtered_tasks, filtered_transformations) + filtered_tasks = [] + filtered_transformations = [] + for task, transformations in 
zip(tasks, task_transformations): + if any( + allowed_type == task.task_type + for allowed_type in settings.VASP_ALLOWED_VASP_TYPES ): - if task.task_type == TaskType.Deformation: - if ( - transformations is None - or not task.input - or not task.input.structure - ): # Do not include deformed tasks without transformation information - self.logger.debug( - "Cannot find transformation or original structure " - f"for deformation task {task.task_id}. Excluding task." - ) - continue - else: - s = undeform_structure(task.input.structure, transformations) - - elif task.output and task.output.structure: - s = task.output.structure # type: ignore[assignment] - else: - self.logger.debug( - f"Skipping task {task.task_id}, missing output structure." - ) + filtered_tasks.append(task) + filtered_transformations.append(transformations) + structures = [] + for idx, (task, transformations) in enumerate( + zip(filtered_tasks, filtered_transformations) + ): + if task.task_type == TaskType.Deformation: + if transformations is None: + # Do not include deformed tasks without transformation information continue - - s.index: int = idx # type: ignore - structures.append(s) - - grouped_structures = group_structures( - structures, - ltol=self.settings.LTOL, - stol=self.settings.STOL, - angle_tol=self.settings.ANGLE_TOL, - symprec=self.settings.SYMPREC, - ) - for group in grouped_structures: - grouped_tasks = [filtered_tasks[struct.index] for struct in group] # type: ignore - yield grouped_tasks + else: + s = undeform_structure(task.input.structure, transformations) + else: + s = task.output.structure + + s.index = idx + structures.append(s) + + grouped_structures = group_structures( + structures, + ltol=settings.LTOL, + stol=settings.STOL, + angle_tol=settings.ANGLE_TOL, + symprec=settings.SYMPREC, + ) + for group in grouped_structures: + grouped_tasks = [filtered_tasks[struct.index] for struct in group] + yield grouped_tasks diff --git a/emmet-builders/emmet/builders/vasp/task_validator.py 
b/emmet-builders/emmet/builders/vasp/task_validator.py index 58f7bbd6c5..59fd52cf7d 100644 --- a/emmet-builders/emmet/builders/vasp/task_validator.py +++ b/emmet-builders/emmet/builders/vasp/task_validator.py @@ -1,99 +1,43 @@ -import warnings - -from maggma.builders import MapBuilder -from maggma.core import Store +from typing import Any from emmet.builders.settings import EmmetBuildSettings -from emmet.builders.utils import get_potcar_stats -from emmet.core.tasks import TaskDoc -from emmet.core.types.enums import DeprecationMessage -from emmet.core.vasp.calc_types.enums import CalcType +from emmet.builders.utils import filter_map, get_potcar_stats +from emmet.core.tasks import CoreTaskDoc, TaskDoc +from emmet.core.vasp.task_valid import TaskDocument from emmet.core.vasp.validation_legacy import ValidationDoc -warnings.warn( - f"The current version of {__name__}.TaskValidator will be deprecated in version 0.87.0. " - "To continue using legacy builders please install emmet-builders-legacy from git. A PyPI " - "release for emmet-legacy-builders is not planned.", - DeprecationWarning, - stacklevel=2, -) - - -class TaskValidator(MapBuilder): - def __init__( - self, - tasks: Store, - task_validation: Store, - potcar_stats: dict[CalcType, dict[str, str]] | None = None, - settings: EmmetBuildSettings | None = None, - query: dict | None = None, - **kwargs, - ): - """ - Creates task_types from tasks and type definitions - - Args: - tasks: Store of task documents - task_validation: Store of task_types for tasks - potcar_stats: Optional dictionary of potcar hash data. - Mapping is calculation type -> potcar symbol -> hash value. 
- """ - self.tasks = tasks - self.task_validation = task_validation - self.settings = EmmetBuildSettings.autoload(settings) - self.query = query - self.kwargs = kwargs - self.potcar_stats = potcar_stats - - # Set up potcar cache if appropriate - if self.settings.VASP_VALIDATE_POTCAR_STATS: - if not self.potcar_stats: - self.potcar_stats = get_potcar_stats(method="stored") - else: - self.potcar_stats = None - super().__init__( - source=tasks, - target=task_validation, - projection=[ - "orig_inputs", - "input.hubbards", - "output.structure", - "output.bandgap", - "chemsys", - "calcs_reversed", - ], - query=query, - **kwargs, +def build_validation_doc( + input_documents: list[CoreTaskDoc | TaskDoc | TaskDocument], + settings: EmmetBuildSettings = EmmetBuildSettings(), + potcar_stats: dict[str, Any] = get_potcar_stats(method="stored"), + **kwargs +) -> list[ValidationDoc]: + """ + Build a ValidationDoc from a CoreTaskDoc by checking CoreTaskDoc + parameters against reference values. + + Args: + input: List of parsed task document to validate. + settings: Reference values used in validation, defaults defined in EmmetBuildSettings. + Relevant settings: VASP_KSPACING_TOLERANCE, VASP_DEFAULT_INPUT_SETS, VASP_CHECKED_LDAU_FIELDS, + VASP_MAX_SCF_GRADIENT, and DEPRECATED_TAGS. + potcar_stats: POTCAR stats used to validate POTCARs used for the source calculation + for 'input'. 
Defaults to compiled values in 'emmet.builders.vasp.mp_potcar_stats.json.gz' + + Returns: + list[ValidationDoc] + """ + return list( + filter_map( + ValidationDoc.from_task_doc, + input_documents, + kspacing_tolerance=settings.VASP_KSPACING_TOLERANCE, + input_sets=settings.VASP_DEFAULT_INPUT_SETS, + LDAU_fields=settings.VASP_CHECKED_LDAU_FIELDS, + max_allowed_scf_gradient=settings.VASP_MAX_SCF_GRADIENT, + potcar_stats=potcar_stats, + bad_tags=settings.DEPRECATED_TAGS, + **kwargs ) - - def unary_function(self, item): - """ - Find the task_type for the item - - Args: - item (dict): a (projection of a) task doc - """ - task_doc = TaskDoc(**item) - validation_doc = ValidationDoc.from_task_doc( - task_doc=task_doc, - kpts_tolerance=self.settings.VASP_KPTS_TOLERANCE, - kspacing_tolerance=self.settings.VASP_KSPACING_TOLERANCE, - input_sets=self.settings.VASP_DEFAULT_INPUT_SETS, - LDAU_fields=self.settings.VASP_CHECKED_LDAU_FIELDS, - max_allowed_scf_gradient=self.settings.VASP_MAX_SCF_GRADIENT, - potcar_stats=self.potcar_stats, - ) - - if task_doc.tags: - bad_tags = list( - set(task_doc.tags).intersection(self.settings.DEPRECATED_TAGS) - ) - if len(bad_tags) > 0: - validation_doc.warnings.append( - f"Manual Deprecation by tags: {bad_tags}" - ) - validation_doc.valid = False - validation_doc.reasons.append(DeprecationMessage.MANUAL) - - return validation_doc + ) diff --git a/emmet-builders/pyproject.toml b/emmet-builders/pyproject.toml index 90fae6ceaf..1084eec30d 100644 --- a/emmet-builders/pyproject.toml +++ b/emmet-builders/pyproject.toml @@ -27,8 +27,7 @@ authors = [ ] license = { text = "Modified BSD" } dependencies = [ - "emmet-core[all]>=0.85", - "maggma>=0.57.6", + "emmet-core[all]>=0.86.1", "matminer>=0.9.1", "pymatgen-io-validation>=0.1.1", ] diff --git a/emmet-builders/tests/__init__.py b/emmet-builders/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/emmet-builders/tests/conftest.py b/emmet-builders/tests/conftest.py deleted 
file mode 100644 index a6401fe77c..0000000000 --- a/emmet-builders/tests/conftest.py +++ /dev/null @@ -1,7 +0,0 @@ -from emmet.core.testing_utils import _get_test_files_dir -import pytest - - -@pytest.fixture(scope="session") -def test_dir(): - return _get_test_files_dir("emmet.builders") diff --git a/emmet-builders/tests/test_absorption.py b/emmet-builders/tests/test_absorption.py deleted file mode 100644 index c8421645ba..0000000000 --- a/emmet-builders/tests/test_absorption.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.absorption_spectrum import AbsorptionBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "sample_absorptions.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def absorption_store(): - return MemoryStore(key="material_id") - - -def test_absorption_builder(tasks_store, absorption_store, materials_store): - builder = AbsorptionBuilder( - tasks=tasks_store, absorption=absorption_store, materials=materials_store - ) - builder.run() - - assert absorption_store.count() == 1 - assert absorption_store.count({"deprecated": False}) == 1 - - -def test_serialization(tmpdir): - builder = AbsorptionBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_basic_descriptors.py b/emmet-builders/tests/test_basic_descriptors.py deleted file mode 100644 index b5bd65f2bd..0000000000 --- a/emmet-builders/tests/test_basic_descriptors.py +++ /dev/null @@ 
-1,30 +0,0 @@ -import pytest -from maggma.stores import JSONStore, MemoryStore - -# from emmet.builders.materials.basic_descriptors import BasicDescriptorsBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_As2SO6_tasks.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -# @pytest.mark.skip(reason="Waiting on matminer update") -# def test_basic_descriptions(materials_store): -# descriptors_store = MemoryStore() -# builder = BasicDescriptorsBuilder( -# materials=materials_store, descriptors=descriptors_store -# ) -# builder.run() -# -# print(descriptors_store.query_one({})) -# assert descriptors_store.count() == 1 diff --git a/emmet-builders/tests/test_chemenv.py b/emmet-builders/tests/test_chemenv.py deleted file mode 100644 index e06fb53dee..0000000000 --- a/emmet-builders/tests/test_chemenv.py +++ /dev/null @@ -1,40 +0,0 @@ -from emmet.core.base import EmmetMeta -import pytest -from maggma.stores import JSONStore, MemoryStore - -from emmet.builders.materials.chemenv import ChemEnvBuilder -from emmet.builders.materials.oxidation_states import OxidationStatesBuilder - - -@pytest.fixture(scope="session") -def fake_materials(test_dir): - entries = JSONStore(test_dir / "LiTiO2_batt.json.gz", key="entry_id") - entries.connect() - - materials_store = MemoryStore(key="material_id") - materials_store.connect() - - for doc in entries.query(): - builder_meta = EmmetMeta(license="BY-C").model_dump() - materials_store.update( - { - "material_id": doc["entry_id"], - "structure": doc["structure"], - "deprecated": False, - "builder_meta": builder_meta, - } - ) - return materials_store - - -def test_chemenvstore(fake_materials): - oxi_store = MemoryStore() - 
builder = OxidationStatesBuilder( - materials=fake_materials, oxidation_states=oxi_store - ) - builder.run() - chemenv_store = MemoryStore() - builder2 = ChemEnvBuilder(oxidation_states=oxi_store, chemenv=chemenv_store) - builder2.run() - assert chemenv_store.count() == 6 - assert all([isinstance(d["composition"], dict) for d in chemenv_store.query()]) diff --git a/emmet-builders/tests/test_corrected_entries_thermo.py b/emmet-builders/tests/test_corrected_entries_thermo.py deleted file mode 100644 index cafaa2bdbd..0000000000 --- a/emmet-builders/tests/test_corrected_entries_thermo.py +++ /dev/null @@ -1,78 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.core.mpid import AlphaID - -from emmet.builders.materials.corrected_entries import CorrectedEntriesBuilder -from emmet.builders.materials.thermo import ThermoBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_si_tasks.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture(scope="session") -def corrected_entries_store(): - return MemoryStore(key="chemsys") - - -@pytest.fixture -def thermo_store(): - return MemoryStore(key="thermo_id") - - -@pytest.fixture -def phase_diagram_store(): - return MemoryStore(key="phase_diagram_id") - - -def test_corrected_entries_serialization(tmpdir): - builder = CorrectedEntriesBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") - - -def test_thermo_builder( - corrected_entries_store, materials_store, thermo_store, phase_diagram_store -): - - 
ce_builder = CorrectedEntriesBuilder( - materials=materials_store, corrected_entries=corrected_entries_store - ) - ce_builder.run() - - assert corrected_entries_store.count() == 1 - assert corrected_entries_store.count({"chemsys": "Si"}) == 1 - - thermo_builder = ThermoBuilder( - thermo=thermo_store, - corrected_entries=corrected_entries_store, - phase_diagram=phase_diagram_store, - ) - thermo_builder.run() - - assert thermo_store.count() == 1 - assert thermo_store.count({"material_id": str(AlphaID("mp-149"))}) == 1 - - assert phase_diagram_store.count() == 1 - - -def test_thermo_serialization(tmpdir): - builder = ThermoBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_dielectric.py b/emmet-builders/tests/test_dielectric.py deleted file mode 100644 index d5d5345f10..0000000000 --- a/emmet-builders/tests/test_dielectric.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.dielectric import DielectricBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_si_tasks.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def dielectric_store(): - return MemoryStore(key="material_id") - - -def test_dielectric_builder(tasks_store, dielectric_store, materials_store): - builder = DielectricBuilder( - tasks=tasks_store, dielectric=dielectric_store, materials=materials_store - ) - builder.run() - - assert dielectric_store.count() == 1 - assert 
dielectric_store.count({"deprecated": False}) == 1 - - -def test_serialization(tmpdir): - builder = DielectricBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_elasticity.py b/emmet-builders/tests/test_elasticity.py deleted file mode 100644 index 0e0ac663ca..0000000000 --- a/emmet-builders/tests/test_elasticity.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.elasticity import ElasticityBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "elasticity/SiC_tasks.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def elasticity_store(): - return MemoryStore(key="material_id") - - -def test_elasticity_builder(tasks_store, materials_store, elasticity_store): - builder = ElasticityBuilder( - tasks=tasks_store, materials=materials_store, elasticity=elasticity_store - ) - builder.run() - - assert elasticity_store.count() == 6 - assert elasticity_store.count({"deprecated": False}) == 6 - - -def test_serialization(tmpdir): - builder = ElasticityBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_electronic_structure.py b/emmet-builders/tests/test_electronic_structure.py deleted file mode 100644 index 0f08d65cc3..0000000000 --- a/emmet-builders/tests/test_electronic_structure.py +++ /dev/null @@ -1,66 +0,0 @@ -from pathlib 
import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.electronic_structure import ElectronicStructureBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore( - test_dir / "electronic_structure/es_task_docs.json.gz", key="task_id" - ) - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def electronic_structure_store(): - return MemoryStore(key="material_id") - - -@pytest.fixture -def bandstructure_fs(test_dir): - return JSONStore( - test_dir / "electronic_structure/es_bs_objs.json.gz", key="task_id" - ) - - -@pytest.fixture -def dos_fs(test_dir): - return JSONStore( - test_dir / "electronic_structure/es_dos_objs.json.gz", key="task_id" - ) - - -def test_electronic_structure_builder( - tasks_store, materials_store, electronic_structure_store, bandstructure_fs, dos_fs -): - builder = ElectronicStructureBuilder( - tasks=tasks_store, - materials=materials_store, - electronic_structure=electronic_structure_store, - bandstructure_fs=bandstructure_fs, - dos_fs=dos_fs, - ) - - builder.run() - assert electronic_structure_store.count() == 3 - - -def test_serialization(tmpdir): - builder = ElectronicStructureBuilder( - MemoryStore(), MemoryStore(), MemoryStore(), MemoryStore(), MemoryStore() - ) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_magnetism.py b/emmet-builders/tests/test_magnetism.py deleted file mode 100644 index 4f90555c71..0000000000 --- a/emmet-builders/tests/test_magnetism.py +++ /dev/null @@ -1,55 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores 
import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.magnetism import MagneticBuilder -from emmet.builders.vasp.materials import MaterialsBuilder -from emmet.core.mpid import AlphaID - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "magnetism/magnetism_task_docs.json.gz") - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def magnetism_store(): - return MemoryStore(key="material_id") - - -def test_magnetism_builder(tasks_store, magnetism_store, materials_store): - builder = MagneticBuilder( - tasks=tasks_store, magnetism=magnetism_store, materials=materials_store - ) - builder.run() - - assert magnetism_store.count() == 4 - assert magnetism_store.count({"deprecated": False}) == 4 - - test_mpids = { - "mp-1289887": "AFM", - "mp-1369002": "FiM", - "mp-1791788": "NM", - "mp-1867075": "FM", - } - - for mpid, ordering in test_mpids.items(): - doc = magnetism_store.query_one({"material_id": str(AlphaID(mpid))}) - assert doc["ordering"] == ordering - - -def test_serialization(tmpdir): - builder = MagneticBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_materials.py b/emmet-builders/tests/test_materials.py deleted file mode 100644 index 4e63d73cf5..0000000000 --- a/emmet-builders/tests/test_materials.py +++ /dev/null @@ -1,46 +0,0 @@ -from pathlib import Path -from emmet.builders.settings import EmmetBuildSettings - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.vasp.materials import MaterialsBuilder -from emmet.builders.vasp.task_validator 
import TaskValidator - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_si_tasks.json.gz") - - -@pytest.fixture(scope="session") -def validation_store(tasks_store): - settings = EmmetBuildSettings(VASP_VALIDATE_POTCAR_STATS=False) - validation_store = MemoryStore() - builder = TaskValidator( - tasks=tasks_store, task_validation=validation_store, settings=settings - ) - builder.run() - return validation_store - - -@pytest.fixture -def materials_store(): - return MemoryStore() - - -def test_materials_builder(tasks_store, validation_store, materials_store): - builder = MaterialsBuilder( - tasks=tasks_store, task_validation=validation_store, materials=materials_store - ) - builder.run() - assert materials_store.count() == 1 - assert materials_store.count({"deprecated": False}) == 1 - - -def test_serialization(tmpdir): - builder = MaterialsBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_mobility.py b/emmet-builders/tests/test_mobility.py deleted file mode 100644 index e00958df6c..0000000000 --- a/emmet-builders/tests/test_mobility.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -from maggma.stores import JSONStore, MemoryStore - -from emmet.builders.mobility.migration_graph import MigrationGraphBuilder - - -@pytest.fixture(scope="session") -def ie_store(test_dir): - return JSONStore( - test_dir / "mobility/builder_migration_graph_set.json.gz", key="battery_id" - ) - - -@pytest.fixture -def mg_store(): - return MemoryStore() - - -@pytest.mark.skip( - "Investigate later, modifying structure matcher tolerances has no effect" -) -def test_migration_graph_builder(ie_store, mg_store): - builder = MigrationGraphBuilder( - insertion_electrode=ie_store, migration_graph=mg_store - ) - builder.run() - assert mg_store.count() == 2 - assert mg_store.count({"state": "successful"}) == 2 - assert 
mg_store.count({"deprecated": False}) == 2 - d = builder.as_dict() - assert type(d) is dict diff --git a/emmet-builders/tests/test_oxidation.py b/emmet-builders/tests/test_oxidation.py deleted file mode 100644 index a3ad522f39..0000000000 --- a/emmet-builders/tests/test_oxidation.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest -from maggma.stores import JSONStore, MemoryStore - -from emmet.core.base import EmmetMeta -from emmet.builders.materials.oxidation_states import OxidationStatesBuilder - - -@pytest.fixture(scope="session") -def fake_materials(test_dir): - entries = JSONStore(test_dir / "LiTiO2_batt.json.gz", key="entry_id") - entries.connect() - - materials_store = MemoryStore(key="material_id") - materials_store.connect() - - for doc in entries.query(): - builder_meta = EmmetMeta(license="BY-C").model_dump() - materials_store.update( - { - "material_id": doc["entry_id"], - "structure": doc["structure"], - "deprecated": False, - "builder_meta": builder_meta, - }, - ) - return materials_store - - -def test_oxidation_store(fake_materials): - oxi_store = MemoryStore() - builder = OxidationStatesBuilder( - materials=fake_materials, oxidation_states=oxi_store - ) - builder.run() - - assert oxi_store.count() == 6 - assert all([isinstance(d["composition"], dict) for d in oxi_store.query()]) diff --git a/emmet-builders/tests/test_piezoelectric.py b/emmet-builders/tests/test_piezoelectric.py deleted file mode 100644 index 7bc6e29fe3..0000000000 --- a/emmet-builders/tests/test_piezoelectric.py +++ /dev/null @@ -1,44 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.piezoelectric import PiezoelectricBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_As2SO6_tasks.json.gz") - - -@pytest.fixture(scope="session") -def 
materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def piezoelectric_store(): - return MemoryStore(key="material_id") - - -def test_piezoelectric_builder(tasks_store, piezoelectric_store, materials_store): - - builder = PiezoelectricBuilder( - tasks=tasks_store, piezoelectric=piezoelectric_store, materials=materials_store - ) - builder.run() - - assert piezoelectric_store.count() == 1 - assert piezoelectric_store.count({"deprecated": False}) == 1 - - -def test_serialization(tmpdir): - builder = PiezoelectricBuilder(MemoryStore(), MemoryStore(), MemoryStore()) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_similarity.py b/emmet-builders/tests/test_similarity.py deleted file mode 100644 index 3e3a90a3ff..0000000000 --- a/emmet-builders/tests/test_similarity.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest -from maggma.stores import JSONStore, MemoryStore - -# from emmet.builders.materials.basic_descriptors import BasicDescriptorsBuilder -from emmet.builders.vasp.materials import MaterialsBuilder -from emmet.builders.materials.similarity import StructureSimilarityBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore( - [test_dir / "test_si_tasks.json.gz", test_dir / "test_As2SO6_tasks.json.gz"] - ) - - -@pytest.fixture(scope="session") -def materials_store(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -# @pytest.fixture(scope="session") -# def descriptors_store(materials_store): -# descriptors_store = MemoryStore(key="task_id") -# builder = BasicDescriptorsBuilder( -# materials=materials_store, descriptors=descriptors_store -# ) -# 
builder.run() -# return descriptors_store - - -@pytest.mark.skip(reason="Waiting on matminer update") -def test_basic_descriptions(descriptors_store): - similarity_store = MemoryStore() - builder = StructureSimilarityBuilder( - structure_similarity=similarity_store, site_descriptors=descriptors_store - ) - builder.run() - - print(similarity_store.query_one({})) - assert similarity_store.count() == 1 diff --git a/emmet-builders/tests/test_summary.py b/emmet-builders/tests/test_summary.py deleted file mode 100644 index 22edc4831b..0000000000 --- a/emmet-builders/tests/test_summary.py +++ /dev/null @@ -1,193 +0,0 @@ -from pathlib import Path - -import pytest -from maggma.stores import JSONStore, MemoryStore -from monty.serialization import dumpfn, loadfn - -from emmet.builders.materials.summary import SummaryBuilder -from emmet.builders.vasp.materials import MaterialsBuilder - - -@pytest.fixture(scope="session") -def tasks_store(test_dir): - return JSONStore(test_dir / "test_si_tasks.json.gz") - - -@pytest.fixture(scope="session") -def materials(tasks_store): - materials_store = MemoryStore(key="material_id") - builder = MaterialsBuilder(tasks=tasks_store, materials=materials_store) - builder.run() - return materials_store - - -@pytest.fixture -def electronic_structure(): - return MemoryStore(key="material_id") - - -@pytest.fixture -def thermo(): - return MemoryStore(key="material_id") - - -@pytest.fixture -def grain_boundaries(): - return MemoryStore() - - -@pytest.fixture -def chemenv(): - return MemoryStore() - - -@pytest.fixture -def absorption(): - return MemoryStore() - - -@pytest.fixture -def magnetism(): - return MemoryStore() - - -@pytest.fixture -def elasticity(): - return MemoryStore() - - -@pytest.fixture -def dielectric(): - return MemoryStore() - - -@pytest.fixture -def piezoelectric(): - return MemoryStore() - - -@pytest.fixture -def phonon(): - return MemoryStore() - - -@pytest.fixture -def insertion_electrodes(): - return MemoryStore() - - 
-@pytest.fixture -def substrates(): - return MemoryStore() - - -@pytest.fixture -def oxi_states(): - return MemoryStore() - - -@pytest.fixture -def surfaces(): - return MemoryStore() - - -@pytest.fixture -def eos(): - return MemoryStore() - - -@pytest.fixture -def xas(): - return MemoryStore() - - -@pytest.fixture -def provenance(): - return MemoryStore() - - -@pytest.fixture -def charge_density_index(): - return MemoryStore() - - -@pytest.fixture -def summary(): - return MemoryStore(key="material_id") - - -def test_summary_builder( - materials, - thermo, - xas, - chemenv, - absorption, - grain_boundaries, - electronic_structure, - magnetism, - elasticity, - dielectric, - piezoelectric, - phonon, - insertion_electrodes, - substrates, - surfaces, - oxi_states, - eos, - provenance, - charge_density_index, - summary, -): - builder = SummaryBuilder( - materials=materials, - electronic_structure=electronic_structure, - thermo=thermo, - magnetism=magnetism, - chemenv=chemenv, - absorption=absorption, - dielectric=dielectric, - piezoelectric=piezoelectric, - phonon=phonon, - insertion_electrodes=insertion_electrodes, - elasticity=elasticity, - substrates=substrates, - surfaces=surfaces, - oxi_states=oxi_states, - xas=xas, - grain_boundaries=grain_boundaries, - eos=eos, - provenance=provenance, - charge_density_index=charge_density_index, - summary=summary, - ) - - builder.run() - assert summary.count() == 1 - - -def test_serialization(tmpdir): - builder = SummaryBuilder( - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - MemoryStore(), - ) - - dumpfn(builder.as_dict(), Path(tmpdir) / "test.json") - loadfn(Path(tmpdir) / "test.json") diff --git a/emmet-builders/tests/test_utils.py 
b/emmet-builders/tests/test_utils.py deleted file mode 100644 index 4e26e43413..0000000000 --- a/emmet-builders/tests/test_utils.py +++ /dev/null @@ -1,105 +0,0 @@ -from emmet.builders.utils import ( - chemsys_permutations, - maximal_spanning_non_intersecting_subsets, - get_hop_cutoff, - get_potcar_stats, -) -from pymatgen.analysis.diffusion.neb.full_path_mapper import MigrationGraph -from numpy.testing import assert_almost_equal -from monty.serialization import loadfn, dumpfn -from emmet.core.settings import EmmetSettings - -import pytest - - -def test_maximal_spanning_non_intersecting_subsets(): - assert maximal_spanning_non_intersecting_subsets([{"A"}, {"A", "B"}]) == { - frozenset(d) for d in [{"A"}, {"B"}] - } - - assert maximal_spanning_non_intersecting_subsets([{"A", "B"}, {"A", "B", "C"}]) == { - frozenset(d) for d in [{"A", "B"}, {"C"}] - } - - assert maximal_spanning_non_intersecting_subsets( - [{"A", "B"}, {"A", "B", "C"}, {"D"}] - ) == {frozenset(d) for d in [{"A", "B"}, {"C"}, {"D"}]} - - -def test_chemsys_permutations(test_dir): - assert len(chemsys_permutations("Sr")) == 1 - assert len(chemsys_permutations("Sr-Hf")) == 3 - assert len(chemsys_permutations("Sr-Hf-O")) == 7 - - -def test_get_hop_cutoff(test_dir): - spinel_mg = loadfn(test_dir / "mobility/migration_graph_spinel_MgMn2O4.json.gz") - nasicon_mg = loadfn( - test_dir / "mobility/migration_graph_nasicon_MgV2(PO4)3.json.gz" - ) - - # tests for "min_distance" algorithm - assert_almost_equal( - get_hop_cutoff(spinel_mg.structure, "Mg", algorithm="min_distance"), - 1.95, - decimal=2, - ) - assert_almost_equal( - get_hop_cutoff(nasicon_mg.structure, "Mg", algorithm="min_distance"), - 3.80, - decimal=2, - ) - - # test for "hops_based" algorithm, terminated by number of unique hops condition - d = get_hop_cutoff(spinel_mg.structure, "Mg", algorithm="hops_based") - check_mg = MigrationGraph.with_distance(spinel_mg.structure, "Mg", d) - assert_almost_equal(d, 4.18, decimal=2) - assert 
len(check_mg.unique_hops) == 5 - - # test for "hops_based" algorithm, terminated by the largest hop length condition - d = get_hop_cutoff(nasicon_mg.structure, "Mg", algorithm="hops_based") - check_mg = MigrationGraph.with_distance(nasicon_mg.structure, "Mg", d) - assert_almost_equal(d, 4.59, decimal=2) - assert len(check_mg.unique_hops) == 6 - - -@pytest.mark.parametrize("method", ("potcar", "pymatgen", "stored")) -def test_get_potcar_stats(method: str, tmp_path): - calc_type = EmmetSettings().VASP_DEFAULT_INPUT_SETS - - try: - potcar_stats = get_potcar_stats(method=method) - except Exception as exc: - if any( - exc_str in str(exc) for exc_str in ("Set PMG_VASP_PSP_DIR", "No POTCAR for") - ): - # No Potcar library available, skip test - return - else: - raise exc - - # ensure that all calc types are included in potcar_stats - assert potcar_stats.keys() == calc_type.keys() - - for calc_type in potcar_stats: - # ensure that each entry has needed fields for both - # legacy and modern potcar validation - assert all( - [ - set(entry) == set(["hash", "keywords", "titel", "stats"]) - for entry in entries - ] - for entries in potcar_stats[calc_type].values() - ) - - if method == "stored": - new_stats_path = tmp_path / "_temp_potcar_stats.json" - dumpfn(potcar_stats, new_stats_path) - - new_potcar_stats = get_potcar_stats( - method="stored", path_to_stored_stats=new_stats_path - ) - assert all( - potcar_stats[calc_type] == new_potcar_stats[calc_type] - for calc_type in potcar_stats - ) diff --git a/emmet-builders/tests/test_vasp.py b/emmet-builders/tests/test_vasp.py deleted file mode 100644 index d3c99c9821..0000000000 --- a/emmet-builders/tests/test_vasp.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -from maggma.stores import JSONStore, MemoryStore - -from emmet.builders.settings import EmmetBuildSettings -from emmet.builders.vasp.task_validator import TaskValidator - -intermediate_stores = ["validation"] - - -@pytest.fixture(scope="session") -def 
tasks_store(test_dir): - return JSONStore(test_dir / "test_si_tasks.json.gz") - - -@pytest.fixture(scope="session") -def validation_store(): - return MemoryStore() - - -def test_validator(tasks_store, validation_store): - settings = EmmetBuildSettings(VASP_VALIDATE_POTCAR_STATS=False) - builder = TaskValidator( - tasks=tasks_store, task_validation=validation_store, settings=settings - ) - builder.run() - assert validation_store.count() == tasks_store.count() - assert validation_store.count({"valid": True}) == tasks_store.count() - assert all( - list(d["run_type"]["value"] == "GGA" for d in list(validation_store.query())) - ) diff --git a/emmet-core/emmet/core/absorption.py b/emmet-core/emmet/core/absorption.py index 40178ab67a..ec30b59b44 100644 --- a/emmet-core/emmet/core/absorption.py +++ b/emmet-core/emmet/core/absorption.py @@ -18,8 +18,6 @@ class AbsorptionDoc(PropertyDoc): property_name: str = "Optical absorption spectrum" - task_id: str = Field(..., description="Calculation id") - energies: list[float] = Field( ..., description="Absorption energy in eV starting from 0" ) @@ -43,7 +41,7 @@ class AbsorptionDoc(PropertyDoc): bandgap: float | None = Field(None, description="The electronic band gap") - nkpoints: float | None = Field( + nkpoints: int | None = Field( None, description="The number of kpoints used in the calculation" ) @@ -57,7 +55,6 @@ def _convert_list_to_tensor(cls, l): def from_structure( cls, energies: list, - task_id: str, real_d: list[np.ndarray], imag_d: list[np.ndarray], absorption_co: list, @@ -87,7 +84,6 @@ def from_structure( "average_real_dielectric": real_d_average, "bandgap": bandgap, "nkpoints": nkpoints, - "task_id": task_id, }, **kwargs, ) diff --git a/emmet-core/emmet/core/arrow.py b/emmet-core/emmet/core/arrow.py index 60b0d7f0ff..74d5874e3e 100644 --- a/emmet-core/emmet/core/arrow.py +++ b/emmet-core/emmet/core/arrow.py @@ -1,7 +1,6 @@ import sys import types import typing -from typing_extensions import NotRequired from 
collections.abc import Iterable, Mapping from datetime import datetime from enum import Enum @@ -13,6 +12,7 @@ from monty.json import MSONable from pydantic._internal._model_construction import ModelMetaclass from pydantic.types import ImportString +from typing_extensions import NotRequired RED = "\033[31m" BLUE = "\033[34m" @@ -23,7 +23,7 @@ float: pa.float64(), str: pa.string(), bool: pa.bool_(), - datetime: pa.timestamp("us"), + datetime: pa.timestamp("us", tz="UTC"), } diff --git a/emmet-core/emmet/core/band_theory.py b/emmet-core/emmet/core/band_theory.py index 50a7f5a032..38bb5f640e 100644 --- a/emmet-core/emmet/core/band_theory.py +++ b/emmet-core/emmet/core/band_theory.py @@ -22,6 +22,7 @@ from emmet.core.math import Matrix3D, Vector3D from emmet.core.settings import EmmetSettings from emmet.core.types.pymatgen_types.structure_adapter import StructureType +from emmet.core.vasp.calc_types import RunType if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence @@ -41,6 +42,9 @@ class BandTheoryBase(BaseModel): structure: StructureType | None = Field( None, description="The structure associated with this calculation." ) + run_type: RunType | None = Field( + None, description="The functional used in the calculation." 
+ ) def _deser_lattice(lattice: Lattice | dict | Matrix3D) -> Matrix3D: @@ -129,7 +133,7 @@ def to_pmg_like(self) -> dict[Spin, np.ndarray]: class ElectronicBS(BandStructure): """Define an electronic band structure schema.""" - path_convention: str | None = Field( + path_convention: BSPathType | None = Field( None, description="High symmetry path convention of the band structure" ) @@ -181,7 +185,7 @@ def from_pmg(cls, ebs: PmgBandStructure, **kwargs) -> Self: ) ) except Exception: - bs_type = None + bs_type = BSPathType.unknown config = { "qpoints": [qpt.frac_coords for qpt in ebs.kpoints], diff --git a/emmet-core/emmet/core/connectors/__init__.py b/emmet-core/emmet/core/connectors/__init__.py new file mode 100644 index 0000000000..7c0e4c6146 --- /dev/null +++ b/emmet-core/emmet/core/connectors/__init__.py @@ -0,0 +1 @@ +"""Aggregate resources for making external queries to databases.""" diff --git a/emmet-core/emmet/core/connectors/analysis.py b/emmet-core/emmet/core/connectors/analysis.py new file mode 100644 index 0000000000..c48f1e3600 --- /dev/null +++ b/emmet-core/emmet/core/connectors/analysis.py @@ -0,0 +1,246 @@ +"""Tools for processing database CIFs.""" + +from __future__ import annotations + +from contextlib import redirect_stderr, redirect_stdout, nullcontext, contextmanager +from io import StringIO +from tempfile import NamedTemporaryFile +from typing import TYPE_CHECKING + +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer +from pymatgen.core import Structure +from pymatgen.io.cif import CifParser, CifBlock + +from emmet.core.settings import EmmetSettings + +try: + from pycodcif import parse as cod_tools_parse_cif +except ImportError: + cod_tools_parse_cif = None + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + +EMMET_SETTINGS = EmmetSettings() + + +@contextmanager +def _get_context_manager(verbose: bool) -> AbstractContextManager: + if verbose: + yield nullcontext() + else: + with redirect_stderr(StringIO()) as stderr, 
redirect_stdout( + StringIO() + ) as stdout: + yield stderr, stdout + + +def parse_cif_cod_tools( + cif_str: str, + verbose: bool = False, + cif_parser: CifParser | None = None, +) -> tuple[list[Structure], list[str]]: + """Parse a CIF with the COD tools parser. + + Parameters + ----------- + cif_str : str + The CIF string to parse + verbose : bool = False + Whether to pass error messages from pymatgen and CIF parsing tools. + Defaults to suppressing these messages. + cif_parser : pymatgen.io.cif.CifParser or None (default) + Existing instance of a CifParser to use + + Returns + ----------- + List of Structure if parsing is successful + List of str documenting any parsing issues + """ + + structures: list[Structure] = [] + remarks: list[str] = [] + + temp_file = NamedTemporaryFile(suffix=".cif") + cif_data = [] + with open(temp_file.name, "w", encoding="utf-8") as f: + # remove non-ASCII characters + f.write(cif_str.encode("ascii", "ignore").decode("ascii")) + f.seek(0) + try: + cif_data, _, _ = cod_tools_parse_cif(temp_file.name) + except Exception as exc: + remarks += [f"pycodcif.parse: {exc}"] + + temp_file.close() + + if cif_data: + + try: + cif_parser = cif_parser or CifParser.from_str(cif_str) + with _get_context_manager(verbose): + structures += [ + cif_parser._get_structure( + CifBlock(block["values"], block["loops"], block["name"]), + primitive=True, + symmetrized=False, + check_occu=True, + ) + for block in cif_data + ] + except Exception as exc: + remarks += [f"pycodcif/pymatgen: {exc}"] + + return structures, remarks + + +def remove_artificial_disorder( + structures: list[Structure], in_place: bool = True +) -> list[Structure]: + """Remove artificial disorder from a structure. + + Some of the ICSD CIFs are disordered in oxidation states only. + Because these are assigned by hand and don't reflect + actual chemical or configurational disorder, we + remove this artificial disorder here. 
+ + Parameters + ----------- + structures : list of Structure + in_place : bool = True + Whether to modify `Structure`s in place + + Returns + ----------- + list of Structure + """ + output_structs = structures if in_place else [None] * len(structures) + for idx, structure in enumerate(structures): + if ( + not structure.is_ordered + and ( + non_oxi_struct := Structure( + structure.lattice, + species=[site.species for site in structure], + coords=structure.frac_coords, + coords_are_cartesian=False, + charge=structure.charge, + ).remove_oxidation_states() + ).is_ordered + ): + output_structs[idx] = non_oxi_struct + return output_structs + + +def remove_structures_with_fictive_elements( + structures: list[Structure], +) -> list[Structure]: + """Remove structures with fictive elements. + + This is used to ensure that a Structure contains only real elements. + + Sometimes, ICSD structures will use fictive elements to represent, + e.g., cation substitution. Without a list of substituents, + this is not useful for atomistic modelling. + + Parameters + ----------- + structures : list of Structure + + Returns + ----------- + list of Structure + """ + output_structs = [] + for structure in structures: + try: + _ = structure.composition.remove_charges().as_dict() + output_structs.append(structure) + except Exception: + continue + return output_structs + + +def remove_structures_with_unphysical_symmetry( + structures: list[Structure], +) -> list[Structure]: + """Remove structures whose symmetry cannot be determined. + + Sometimes the distances between atoms in a CIF is + unphysically small, or some other issue prevents symmetry + determination of a CIF. 
+ + Parameters + ----------- + structures : list of Structure + + Returns + ----------- + list of Structure + """ + output_structures = [] + for structure in structures: + try: + sga = SpacegroupAnalyzer( + structure, + symprec=EMMET_SETTINGS.SYMPREC, + angle_tolerance=EMMET_SETTINGS.ANGLE_TOL, + ) + _ = sga.get_space_group_number() + output_structures.append(structure) + except Exception: + continue + return output_structures + + +def parse_cif(cif_str: str, verbose: bool = False) -> tuple[list[Structure], list[str]]: + """Parse a CIF string and apply sanity checks. + + Parameters + ----------- + cif_str : str + The CIF string to parse + verbose : bool = False + Whether to pass error messages from pymatgen and CIF parsing tools. + Defaults to suppressing these messages. + + Returns + ----------- + List of Structure if parsing is successful + List of str documenting any parsing issues + """ + + structures: list[Structure] = [] + remarks: list[str] = [] + + cif_parser = CifParser.from_str(cif_str, check_cif=False) + # Step 1: Try to parse with pymatgen without any changes to the CIF + try: + + with _get_context_manager(verbose): + structures = cif_parser.parse_structures(primitive=True) + + except Exception as exc: + remarks.append(f"pymatgen.io.cif.CifParser: {exc}") + + # Step 2 (Optional): Use the Crystallography Open Database CIF parser + # to correct errors in the CIF if the structures could not be parsed. 
+ if not structures and cod_tools_parse_cif: + structures, new_remarks = parse_cif_cod_tools( + cif_str, verbose=verbose, cif_parser=cif_parser + ) + remarks.extend(new_remarks) + + # Step 3: Remove structures with fictive elements + structures = remove_structures_with_fictive_elements(structures) + + # Step 4: Remove structures whose symmetry cannot be determined + structures = remove_structures_with_unphysical_symmetry(structures) + + # Step 5: Check remaining CIFs with pymatgen CIF checker + structures = [ + structure for structure in structures if (not cif_parser.check(structure)) + ] + + # Step 6: Remove artificial disorder in oxidation states + return remove_artificial_disorder(structures), remarks diff --git a/emmet-core/emmet/core/connectors/icsd/__init__.py b/emmet-core/emmet/core/connectors/icsd/__init__.py new file mode 100644 index 0000000000..6ca512768f --- /dev/null +++ b/emmet-core/emmet/core/connectors/icsd/__init__.py @@ -0,0 +1,5 @@ +"""Tools for querying the ICSD API programmatically.""" + +from emmet.core.connectors.icsd.client import IcsdClient + +__all__ = ["IcsdClient"] diff --git a/emmet-core/emmet/core/connectors/icsd/client.py b/emmet-core/emmet/core/connectors/icsd/client.py new file mode 100644 index 0000000000..79947acc72 --- /dev/null +++ b/emmet-core/emmet/core/connectors/icsd/client.py @@ -0,0 +1,343 @@ +"""Retrieve CIF and metadata from the ICSD API. 
+ +This module is based on + https://github.com/lrcfmd/ICSDClient/ +""" + +from __future__ import annotations + +import os +import re +import requests +from requests.adapters import HTTPAdapter +from time import time +from typing import TYPE_CHECKING +from urllib3.util.retry import Retry + +import logging + +import multiprocessing +import numpy as np +from pydantic import BaseModel, Field, PrivateAttr + +from emmet.core.connectors.icsd.settings import IcsdClientSettings +from emmet.core.connectors.icsd.enums import ( + IcsdAdvancedSearchKeys, + IcsdSubset, + IcsdDataFields, +) +from emmet.core.connectors.icsd.schemas import IcsdPropertyDoc + +if TYPE_CHECKING: + from typing import Any + +SETTINGS = IcsdClientSettings() + +# ICSD tokens expire in one hour +_ICSD_TOKEN_TIMEOUT = 3600 + +logger = logging.getLogger("emmet-core") + + +class IcsdClient(BaseModel): + """Query data via the ICSD API.""" + + username: str = Field(SETTINGS.USERNAME) + password: str = Field(SETTINGS.PASSWORD) + + max_retries: float | None = Field(SETTINGS.MAX_RETRIES) + timeout: float | None = Field(SETTINGS.TIMEOUT) + max_batch_size: float | None = Field(SETTINGS.MAX_BATCH_SIZE) + + use_document_model: bool = Field(True) + num_parallel_requests: int | None = Field(None) + + _auth_token: str | None = PrivateAttr(None) + _session_start_time: float | None = PrivateAttr(None) + _session: requests.Session | None = PrivateAttr(None) + + @property + def _is_windows(self) -> bool: + return os.name == "nt" + + def refresh_session(self, force: bool = False) -> None: + if self._session_start_time is None: + self._session_start_time = time() + + if ( + self._auth_token is None + or ((time() - self._session_start_time) > 0.98 * _ICSD_TOKEN_TIMEOUT) + or force + ): + if self._session: + self.logout() + self._session_start_time = time() + self.login() + + def login(self) -> None: + + response = requests.post( + "https://icsd.fiz-karlsruhe.de/ws/auth/login", + headers={ + "accept": "text/plain", + 
"Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "loginid": self.username, + "password": self.password, + }, + ) + if response.status_code == 200: + self._auth_token = response.headers["ICSD-Auth-Token"] + if self._auth_token is None: + logger.warning( + f"{self.__module__}.{self.__class__.__name__} " + f"failed to fetch auth token: {response.content}" + ) + else: + logger.warning( + f"{self.__module__}.{self.__class__.__name__} " + "failed to fetch auth token with status code " + f"{response.status_code}: {response.content}" + ) + + self._session = requests.Session() + self._session.headers = {"ICSD-Auth-Token": self._auth_token} + retry = Retry( + total=self.max_retries, + read=self.max_retries, + connect=self.max_retries, + respect_retry_after_header=True, + status_forcelist=[429, 504, 502], # rate limiting + backoff_factor=0.1, + ) + adapter = HTTPAdapter(max_retries=retry) + self._session.mount("http://", adapter) + self._session.mount("https://", adapter) + + def logout(self) -> None: + + if not self._session: + return + + _ = self._session.get( + "https://icsd.fiz-karlsruhe.de/ws/auth/logout", + headers={ + "accept": "text/plain", + }, + params=[("windowsclient", self._is_windows)], + ) + self._auth_token = None + self._session_start_time = None + self._session.close() + self._session = None + + def __enter__(self) -> None: + self.login() + return self + + def __exit__(self, *args) -> None: + self.logout() + + def __del__(self) -> None: + self.logout() + + def _get(self, *args, **kwargs) -> requests.Response: + self.refresh_session() + params = tuple( + list(kwargs.pop("params", [])) + [("windowsclient", self._is_windows)] + ) + resp = self._session.get(*args, **kwargs, params=params) + if resp.status_code != 200: + logger.warning( + f"{self.__module__}.{self.__class__.__name__} " + "failed to fetch content with status code " + f"{resp.status_code}: {resp.content}" + ) + return resp + + def _get_cifs(self, collection_codes: int | 
list[int]) -> dict[int, str]: + if isinstance(collection_codes, int) or len(collection_codes) == 1: + cif_str = self._get( + f"https://icsd.fiz-karlsruhe.de/ws/cif/{collection_codes[0]}", + headers={ + "accept": "application/cif", + }, + ).content.decode() + else: + cif_str = self._get( + "https://icsd.fiz-karlsruhe.de/ws/cif/multiple", + headers={ + "accept": "application/cif", + }, + params=[("idnum", collection_codes)], + ).content.decode() + + return { + int(re.search(r"_database_code_ICSD ([0-9]+)", cif_body).group(1)): "#(C)" + + cif_body + for cif_body in cif_str.split("\n#(C)")[1:] + } + + def _search( + self, + indices: list[int], + properties: list[str | IcsdDataFields] | None = None, + include_cif: bool = False, + include_metadata: bool = False, + _data: list | None = None, + ) -> list[dict[str, Any]]: + + self.refresh_session(force=True) + search_props = [ + ( + prop.value + if isinstance(prop, IcsdDataFields) + else IcsdDataFields(prop).value + ) + for prop in (properties or list(IcsdDataFields)) + ] + + if len(indices) > self.max_batch_size: + batched_ids: list[list[str]] = [ + v.tolist() + for v in np.array_split( + indices, np.ceil(len(indices) / self.max_batch_size) + ) + ] + + data = [] + for i, batch in enumerate(batched_ids): + data.extend( + self._search( + batch, + properties=search_props, + include_cif=include_cif, + include_metadata=include_metadata, + _data=_data, + ) + ) + return data + + if not include_cif and not include_metadata: + return [{"icsd_internal_id": int(idx) for idx in indices}] + + if include_metadata: + if "CollectionCode" not in search_props: + search_props.append("CollectionCode") + + response = self._get( + "https://icsd.fiz-karlsruhe.de/ws/csv", + headers={ + "accept": "application/csv", + }, + params=( + ("idnum", tuple(indices)), + ("listSelection", search_props), + ), + ) + + data = [] + if response.status_code == 200: + csv_data = [ + row.split("\t") for row in response.content.decode().splitlines() + ] + columns = 
csv_data[0][:-1] + + data += [ + {IcsdDataFields[k].value: row[i] for i, k in enumerate(columns)} + for row in csv_data[1:] + ] + else: + logger.warning( + f"{self.__module__}.{self.__class__.__name__} " + "csv search failed with status code " + f"{response.status_code}: {response.content}" + ) + + if include_cif: + cifs = self._get_cifs(indices) + if include_metadata: + for i, doc in enumerate(data): + data[i]["cif"] = cifs.get(int(doc["collection_code"])) + else: + data = [{"collection_code": cc, "cif": cif} for cc, cif in cifs.items()] + + if _data: + _data.extend(data) + return data + + def search( + self, + subset: IcsdSubset | str | None = None, + properties: list[str | IcsdDataFields] | None = None, + include_cif: bool = False, + include_metadata: bool = False, + **kwargs, + ) -> list: + + query_vars = [] + for k in IcsdAdvancedSearchKeys: + if (v := kwargs.get(k.value)) is not None or ( + v := kwargs.get(k.name) is not None + ): + if isinstance(v, tuple): + v = f"{v[0]}-{v[1]}" + elif isinstance(v, list): + v = ",".join(v) + query_vars.append(f"{k.name.lower()} : {v}") + query_str = " and ".join(query_vars) + + params = [("query", query_str)] + if subset: + params.append(("content type", IcsdSubset(subset).name)) + + response = self._get( + "https://icsd.fiz-karlsruhe.de/ws/search/expert", + headers={ + "accept": "application/xml", + }, + params=params, + ) + + idxs: list[str] = [] + if matches := re.match(".*(.*).*", response.content.decode()): + idxs.extend(list(matches.groups())[0].split()) + + if self.num_parallel_requests and len(idxs) > self.num_parallel_requests: + batched_idxs = np.array_split(idxs, self.num_parallel_requests) + + manager = multiprocessing.Manager() + procs = [] + res = manager.list() + for iproc in range(self.num_parallel_requests): + proc = multiprocessing.Process( + target=self._search, + args=(batched_idxs[iproc].tolist(),), + kwargs={ + "properties": properties, + "include_cif": include_cif, + "include_metadata": 
include_metadata, + "_data": res, + }, + ) + proc.start() + procs.append(proc) + + for proc in procs: + proc.join() + return list(res) + + data = self._search( + idxs, + properties=properties, + include_cif=include_cif, + include_metadata=include_metadata, + ) + if subset: + for i in range(len(data)): + data[i]["subset"] = subset + + if self.use_document_model: + data = [IcsdPropertyDoc(**props) for props in data] + return data diff --git a/emmet-core/emmet/core/connectors/icsd/enums.py b/emmet-core/emmet/core/connectors/icsd/enums.py new file mode 100644 index 0000000000..7d9b5aa3a9 --- /dev/null +++ b/emmet-core/emmet/core/connectors/icsd/enums.py @@ -0,0 +1,105 @@ +"""Define ICSD-specific enums.""" + +from enum import Enum + + +class IcsdSubset(Enum): + EXPERIMENTAL_INORGANIC = "experimental_inorganic" + EXPERIMENTAL_METALORGANIC = "experimental_metalorganic" + THERORETICAL_STRUCTURES = "theoretical" + + +class IcsdAdvancedSearchKeys(Enum): + + AUTHORS = "authors" + ARTICLE = "article" + DOI = "doi" + PUBLICATIONYEAR = "publication_year" + PAGEFIRST = "page_first" + JOURNAL = "journal" + VOLUME = "volume" + ABSTRACT = "abstract" + KEYWORDS = "keywords" + CELLVOLUME = "cell_volume" + CALCDENSITY = "calc_density" + CELLPARAMETERS = "cell_parameters" + SEARCH = "search" + STRUCTUREDFORMULA = "structured_formula" + CHEMICALNAME = "chemical_name" + MINERALNAME = "mineral_name" + MINERALNAMEIMA = "mineral_name_ima" + MINERALGROUP = "mineral_group" + ZVALUECHEMISTRY = "z_value_chemistry" + ANXFORMULA = "anx_formula" + ABFORMULA = "ab_formula" + FORMULAWEIGHT = "formula_weight" + NUMBEROFELEMENTS = "number_of_elements" + COMPOSITION = "composition" + COLLECTIONCODE = "collection_code" + PDFNUMBER = "pdf_number" + RELEASE = "release" + RECORDINGDATE = "recording_date" + MODIFICATIONDATE = "modification_date" + COMMENT = "comment" + RVALUE = "r_value" + TEMPERATURE = "temperature" + PRESSURE = "pressure" + SAMPLETYPE = "sample_type" + RADIATIONTYPE = "radiation_type" + 
STRUCTURETYPE = "structure_type" + SPACEGROUPSYMBOL = "space_group_symbol" + SPACEGROUPNUMBER = "space_group_number" + BRAVAISLATTICE = "bravais_lattice" + CRYSTALSYSTEM = "crystal_system" + CRYSTALCLASS = "crystal_class" + LAUECLASS = "laue_class" + WYCKOFFSEQUENCE = "wyckoff_sequence" + PEARSONSYMBOL = "pearson_symbol" + INVERSIONCENTER = "inversion_center" + POLARAXIS = "polaraxis" + + +class IcsdDataFields(Enum): + + CollectionCode = "collection_code" + CcdcNo = "ccdc_no" + HMS = "h_m_s" + StructuredFormula = "structured_formula" + StructureType = "structure_type" + Title = "title" + Authors = "authors" + Reference = "reference" + CellParameter = "cell_parameter" + ReducedCellParameter = "reduced_cell_parameter" + StandardisedCellParameter = "standardised_cell_parameter" + CellVolume = "cell_volume" + FormulaUnitsPerCell = "formula_units_per_cell" + FormulaWeight = "formula_weight" + Temperature = "temperature" + Pressure = "pressure" + RValue = "r_value" + SumFormula = "sum_formula" + ANXFormula = "a_n_x_formula" + ABFormula = "a_b_formula" + ChemicalName = "chemical_name" + MineralName = "mineral_name" + MineralNameIma = "mineral_name_ima" + MineralGroup = "mineral_group" + MineralSeries = "mineral_series" + MineralRootGroup = "mineral_root_group" + MineralSubGroup = "mineral_sub_group" + MineralSuperGroup = "mineral_super_group" + MineralSubClass = "mineral_sub_class" + MineralClass = "mineral_class" + CalculatedDensity = "calculated_density" + MeasuredDensity = "measured_density" + PearsonSymbol = "pearson_symbol" + WyckoffSequence = "wyckoff_sequence" + Journal = "journal" + Volume = "volume" + PublicationYear = "publication_year" + Page = "page" + Quality = "quality" + Keywords = "keywords" + Ccdc = "ccdc" + Pdf = "pdf" diff --git a/emmet-core/emmet/core/connectors/icsd/schemas.py b/emmet-core/emmet/core/connectors/icsd/schemas.py new file mode 100644 index 0000000000..64708d0c84 --- /dev/null +++ b/emmet-core/emmet/core/connectors/icsd/schemas.py @@ -0,0 
+1,154 @@ +"""Define document models used by the client.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel, Field, model_validator, ConfigDict +from uncertainties import ufloat_fromstr + +from emmet.core.connectors.analysis import parse_cif +from emmet.core.connectors.icsd.enums import IcsdSubset +from emmet.core.types.pymatgen_types.structure_adapter import StructureType + +if TYPE_CHECKING: + from typing import Any + + +class UFloat(BaseModel): + value: float | None = None + uncertainty: float | None = None + + @model_validator(mode="before") + @classmethod + def parse_uncert(cls, config: Any) -> Any: + if isinstance(config, str): + if "(" in config: + parsed = ufloat_fromstr(config) + config = {"value": parsed.n, "uncertainty": parsed.s} + else: + try: + config = {"value": float(config)} + except ValueError: + config = {} + return config + + +class CellParameters(BaseModel): + + a: UFloat | None = None + b: UFloat | None = None + c: UFloat | None = None + alpha: UFloat | None = None + beta: UFloat | None = None + gamma: UFloat | None = None + + @model_validator(mode="before") + @classmethod + def from_str(cls, config): + """Parse space-separated lattice parameters.""" + lps = ["a", "b", "c", "alpha", "beta", "gamma"] + if isinstance(config, str): + vals = config.split() + config = {lp: vals[i] for i, lp in enumerate(lps)} + return config + + +class IcsdPropertyDoc(BaseModel): + """General container for ICSD data.""" + + model_config = ConfigDict(use_enum_values=True) + + collection_code: int | None = Field( + None, description="The ICSD identifier of this entry." + ) + + icsd_internal_id: int | None = Field( + None, + description="The internal identifier for the ICSD, not the collection code / ICSD ID.", + ) + cif: str | None = Field( + None, description="The CIF file associated with this entry." 
+ ) + subset: IcsdSubset | None = Field( + None, description="The subset of the ICSD to which this entry belongs." + ) + ccdc_no: int | None = Field(None) + ccdc: int | None = None + + h_m_s: str | None = None + pearson_symbol: str | None = None + wyckoff_sequence: str | None = None + + structured_formula: str | None = None + sum_formula: str | None = None + a_n_x_formula: str | None = None + a_b_formula: str | None = None + + structure_type: str | None = None + title: str | None = None + authors: list[str] | None = None + journal: str | None = None + publication_year: int | None = None + volume: int | None = None + page: str | None = None + reference: str | None = None + + cell_parameter: CellParameters | None = None + reduced_cell_parameter: CellParameters | None = None + standardised_cell_parameter: CellParameters | None = None + + cell_volume: UFloat | None = None + formula_units_per_cell: int | None = None + formula_weight: float | None = None + + temperature: float | None = None + pressure: float | None = None + r_value: float | None = None + + chemical_name: str | None = None + mineral_name: str | None = None + mineral_name_ima: str | None = None + mineral_group: str | None = None + mineral_series: str | None = None + mineral_root_group: str | None = None + mineral_sub_group: str | None = None + mineral_super_group: str | None = None + mineral_sub_class: str | None = None + mineral_class: str | None = None + + calculated_density: UFloat | None = None + measured_density: UFloat | None = None + + quality: int | None = None + keywords: str | None = None + + pdf: str | None = None + + structures: list[StructureType] | None = Field( + None, description="A list of validated `Structure`s parsed from the CIF." + ) + cif_parsing_errors: list[str] | None = Field( + None, description="A list of any errors encountered while parsing the CIF." 
+ ) + + @model_validator(mode="before") + @classmethod + def deserialize(cls, config: Any) -> Any: + """Parse ICSD data into a structured format.""" + if isinstance(config.get("authors"), str): + config["authors"] = config["authors"].split(";") + + for k, v in config.items(): + if isinstance(v, str) and len(v) == 0: + config[k] = None + + if config.get("cif") and config.get("structures") is None: + config["structures"], config["cif_parsing_errors"] = parse_cif( + config["cif"] + ) + + if not config.get("cif_parsing_errors"): + config.pop("cif_parsing_errors") + + return config diff --git a/emmet-core/emmet/core/connectors/icsd/settings.py b/emmet-core/emmet/core/connectors/icsd/settings.py new file mode 100644 index 0000000000..1d97439031 --- /dev/null +++ b/emmet-core/emmet/core/connectors/icsd/settings.py @@ -0,0 +1,28 @@ +"""Define basic settings for the ICSD API client.""" + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class IcsdClientSettings(BaseSettings): + + USERNAME: str | None = Field(None, description="ICSD username.") + PASSWORD: str | None = Field(None, description="ICSD password.") + + MAX_RETRIES: int | None = Field( + 10, description="The maximum number of retries when querying the ICSD API." + ) + + TIMEOUT: float | None = Field( + 15.0, description="The time in seconds to wait for a query to complete." + ) + + MAX_BATCH_SIZE: int | None = Field( + 500, + description=( + "The maximum number of structures to retrieve " + "during pagination of query results." 
+ ), + ) + + model_config = SettingsConfigDict(env_prefix="ICSD_API_") diff --git a/emmet-core/emmet/core/electronic_structure.py b/emmet-core/emmet/core/electronic_structure.py index a4cb77a507..2c85aa24c6 100644 --- a/emmet-core/emmet/core/electronic_structure.py +++ b/emmet-core/emmet/core/electronic_structure.py @@ -12,28 +12,31 @@ CollinearMagneticStructureAnalyzer, Ordering, ) -from pymatgen.core import Structure from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.core import OrbitalType, Spin -from pymatgen.electronic_structure.dos import CompleteDos +from pymatgen.io.vasp.sets import MPStaticSet from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatgen.symmetry.bandstructure import HighSymmKpath +from emmet.core.material import PropertyOrigin from emmet.core.material_property import PropertyDoc -from emmet.core.mpid import AlphaID from emmet.core.settings import EmmetSettings from emmet.core.types.enums import ValueEnum from emmet.core.types.pymatgen_types.bandstructure_symm_line_adapter import ( BandStructureSymmLineType, + TypedBandDict, ) from emmet.core.types.pymatgen_types.dos_adapter import CompleteDosType from emmet.core.types.pymatgen_types.element_adapter import ElementType -from emmet.core.types.typing import DateTimeType, IdentifierType, TypedBandDict +from emmet.core.types.typing import DateTimeType, IdentifierType if TYPE_CHECKING: from typing import Any + + from pymatgen.core import Structure from typing_extensions import Self + from emmet.core.types.electronic_structure import BSShim, DosShim + SETTINGS = EmmetSettings() OrderingType = Annotated[ @@ -47,6 +50,7 @@ class BSPathType(ValueEnum): setyawan_curtarolo = "setyawan_curtarolo" hinuma = "hinuma" latimer_munro = "latimer_munro" + unknown = "unknown" class DOSProjectionType(ValueEnum): @@ -96,12 +100,6 @@ class DOSObjectDoc(BaseModel): class ElectronicStructureBaseData(BaseModel): - task_id: IdentifierType = 
Field( - ..., - description="The source calculation (task) ID for the electronic structure data. " - "This has the same form as a Materials Project ID.", - ) - band_gap: float = Field(..., description="Band gap energy in eV.") cbm: float | None = Field(None, description="Conduction band minimum data.") vbm: float | None = Field(None, description="Valence band maximum data.") @@ -127,26 +125,13 @@ def _deser_cbm_vbm(band: Any) -> TypedBandDict: return band -def _deser_equiv_labels(equivalent_labels: Any): - """Validate band structure equivalent labels.""" - if isinstance(next(iter(equivalent_labels.values())), list): - equivalent_labels = { - convention: { - other_convention: {k: v for k, v in label_tuples} - for other_convention, label_tuples in other_mapping - } - for convention, other_mapping in equivalent_labels.items() - } - return equivalent_labels - - class BandStructureSummaryData(ElectronicStructureSummary): """Schematize high-level band structure data for the API.""" - equivalent_labels: Annotated[ - dict[str, dict[str, dict[str, str]]], BeforeValidator(_deser_equiv_labels) - ] = Field(..., description="Equivalent k-point labels in other k-path conventions.") - + task_id: IdentifierType | None = Field( + None, + description="The source calculation (task) ID that this band structure comes from.", + ) nbands: float = Field(..., description="Number of bands.") direct_gap: float = Field(..., description="Direct gap energy in eV.") cbm: Annotated[TypedBandDict | None, BeforeValidator(_deser_cbm_vbm)] | None = ( @@ -223,6 +208,10 @@ def _deser_orbital(orbital): class DosData(BaseModel): + task_id: IdentifierType | None = Field( + None, + description="The source calculation (task) ID that this density of states comes from.", + ) total: dict[SpinType, DosSummaryData] | None = Field( None, description="Total DOS summary data." 
) @@ -269,84 +258,296 @@ class ElectronicStructureDoc(PropertyDoc, ElectronicStructureSummary): ) @classmethod - def from_bsdos( # type: ignore[override] + def from_bs( cls, - dos: dict[IdentifierType, CompleteDos], - is_gap_direct: bool, - is_metal: bool, - material_id: IdentifierType | None = None, - origins: list[dict] = [], - structures: dict[IdentifierType, Structure] | None = None, - setyawan_curtarolo: dict[IdentifierType, BandStructureSymmLine] | None = None, - hinuma: dict[IdentifierType, BandStructureSymmLine] | None = None, - latimer_munro: dict[IdentifierType, BandStructureSymmLine] | None = None, + bandstructures: BSShim, + origins: list[PropertyOrigin], + structures: dict[IdentifierType, Structure], **kwargs, ) -> Self: """ - Builds a electronic structure document using band structure and density of states data. + Builds an electronic structure document using band structure data. Args: + bandstructures (BSShim): Struct of bandstructures with identifiers. + origins (list[PropertyOrigin]): Optional origins information for final doc. + structures (dict[AlphaID or MPID, Structure]) = Dictionary mapping a calculation (task) ID to the + structures used as inputs. This is to ensures correct magnetic moment information is included. material_id (AlphaID or MPID): A material ID. - dos (dict[AlphaID or MPID, CompleteDos]): Dictionary mapping a calculation (task) ID to a CompleteDos object. - is_gap_direct (bool): Direct gap indicator included at root level of document. - is_metal (bool): Metallic indicator included at root level of document. - structures (dict[AlphaID or MPID, Structure]) = Dictionary mapping a calculation (task) ID to the structures used - as inputs. This is to ensures correct magnetic moment information is included. - setyawan_curtarolo (dict[AlphaID or MPID, BandStructureSymmLine]): Dictionary mapping a calculation (task) ID to a - BandStructureSymmLine object from a calculation run using the Setyawan-Curtarolo k-path convention. 
- hinuma (dict[AlphaID or MPID, BandStructureSymmLine]): Dictionary mapping a calculation (task) ID to a - BandStructureSymmLine object from a calculation run using the Hinuma et al. k-path convention. - latimer_munro (dict[AlphaID or MPID, BandStructureSymmLine]): Dictionary mapping a calculation (task) ID to a - BandStructureSymmLine object from a calculation run using the Latimer-Munro k-path convention. - origins (list[dict]): Optional origins information for final doc """ + bs_data = _generate_bs_data(bandstructures, origins, structures) + origins = [origin for origin in origins] + [bs_data["es_origins_from_bs"]] + + return bs_checks( + cls.from_structure( + band_gap=bs_data["band_gap"], + cbm=bs_data["cbm"], + vbm=bs_data["vbm"], + efermi=bs_data["efermi"], + is_gap_direct=bs_data["is_gap_direct"], + is_metal=bs_data["is_metal"], + magnetic_ordering=bs_data["bs_magnetic_ordering"], + bandstructure=bs_data["bandstructure"], + origins=origins, + **kwargs, + ), + structures, + bandstructures, + ) - # -- Process density of states data + @classmethod + def from_dos( + cls, + dos: DosShim, + is_gap_direct: bool, + origins: list[PropertyOrigin], + structures: dict[IdentifierType, Structure], + **kwargs, + ) -> Self: + """ + Builds an electronic structure document using density of states data. - dos_task, dos_obj = list(dos.items())[0] + Args: + dos (DosShim): Struct with a CompleteDos and identifier. + is_gap_direct (bool): Direct gap indicator included at root level of document, result of VASP outputs. + origins (list[PropertyOrigin]): Origins information for final doc. + structures (dict[AlphaID or MPID, Structure]) = Dictionary mapping a calculation (task) ID to the + structures used as inputs. This is to ensures correct magnetic moment information is included. + material_id (AlphaID or MPID): A material ID. 
- orbitals = [OrbitalType.s, OrbitalType.p, OrbitalType.d] - spins = list(dos_obj.densities.keys()) + """ + dos_data = _generate_dos_data(dos, origins, structures) + origins = [origin for origin in origins] + [dos_data["es_origins_from_dos"]] + + return dos_checks( + cls.from_structure( + band_gap=dos_data["band_gap"], + cbm=dos_data["cbm"], + vbm=dos_data["vbm"], + efermi=dos_data["efermi"], + is_gap_direct=is_gap_direct, + is_metal=dos_data["is_metal"], + magnetic_ordering=dos_data["dos_magnetic_ordering"], + dos=dos_data["dos_entry"], + origins=origins, + **kwargs, + ), + structures, + dos, + ) - ele_dos = dos_obj.get_element_dos() - tot_orb_dos = dos_obj.get_spd_dos() + @classmethod + def from_bsdos( + cls, + bandstructures: BSShim, + dos: DosShim, + origins: list[PropertyOrigin], + structures: dict[IdentifierType, Structure], + **kwargs, + ) -> Self: + """ + Builds an electronic structure document using band structure and density of states data. - elements = ele_dos.keys() + Args: + bandstructures (BSShim): Struct of bandstructures with identifiers. + dos (DosShim): Struct with a CompleteDos and identifier. + origins (list[PropertyOrigin]): Origins information for final doc. + structures (dict[AlphaID or MPID, Structure]) = Dictionary mapping a calculation (task) ID to the + structures used as inputs. This is to ensures correct magnetic moment information is included. + material_id (AlphaID or MPID): A material ID. - dos_efermi = dos_obj.efermi + """ + bs_data = _generate_bs_data(bandstructures, origins, structures) + dos_data = _generate_dos_data(dos, origins, structures) + + # TODO: add ability to add blessed structure from material into ranking + # for es origins, i.e., r2SCAN static/relax > GGA NSCF line > ... 
+ origins = [origin for origin in origins] + [bs_data["es_origins_from_bs"]] + magnetic_ordering = bs_data["bs_magnetic_ordering"] + + return bsdos_checks( + cls.from_structure( + band_gap=bs_data["band_gap"], + cbm=bs_data["cbm"], + vbm=bs_data["vbm"], + efermi=bs_data["efermi"], + is_gap_direct=bs_data["is_gap_direct"], + is_metal=bs_data["is_metal"], + magnetic_ordering=magnetic_ordering, + bandstructure=bs_data["bandstructure"], + dos=dos_data["dos_entry"], + origins=origins, + **kwargs, + ), + structures, + bandstructures, + dos, + ) - is_gap_direct = is_gap_direct - is_metal = is_metal - structure = dos_obj.structure +def _generate_bs_data( + bandstructures: BSShim, + origins: list[PropertyOrigin], + structures: dict[IdentifierType, Structure], +) -> dict: + bs_data = { # type: ignore + "setyawan_curtarolo": bandstructures.setyawan_curtarolo, + "hinuma": bandstructures.hinuma, + "latimer_munro": bandstructures.latimer_munro, + } + + bs_type: str + bs_input: tuple[IdentifierType, BandStructureSymmLine, int] + bs_task_id: IdentifierType + bs: BandStructureSymmLine + + for bs_type, bs_input in bs_data.items(): + if bs_input is not None: + bs_task_id, bs, _ = bs_input + bs_mag_ordering = CollinearMagneticStructureAnalyzer( + structures[bs_task_id] + ).ordering + + gap_dict = bs.get_band_gap() + is_metal = bs.is_metal() + direct_gap = bs.get_direct_band_gap() + + if is_metal: + band_gap = 0.0 + cbm = None # type: ignore[assignment] + vbm = None # type: ignore[assignment] + is_gap_direct = False + else: + band_gap = gap_dict["energy"] + cbm = bs.get_cbm() # type: ignore[assignment] + vbm = bs.get_vbm() # type: ignore[assignment] + is_gap_direct = gap_dict["direct"] + + # coerce type here, mixture of str and int types in bs objects + cbm["kpoint_index"] = [int(x) for x in cbm["kpoint_index"]] # type: ignore[index] + vbm["kpoint_index"] = [int(x) for x in vbm["kpoint_index"]] # type: ignore[index] + + bs_efermi = bs.efermi + nbands = bs.nb_bands + + bs_data[bs_type] 
= BandStructureSummaryData( # type: ignore + task_id=bs_task_id, + band_gap=band_gap, + direct_gap=direct_gap, + cbm=cbm, + vbm=vbm, + is_gap_direct=is_gap_direct, + is_metal=is_metal, + efermi=bs_efermi, + nbands=nbands, + magnetic_ordering=bs_mag_ordering, + ) - if structures is not None and structures[dos_task]: - structure = structures[dos_task] + def _bs_eval( + bs_data: dict[str, BandStructureSymmLine | None], + bs_rank: list[str] = ["latimer_munro", "hinuma", "setyawan_curtarolo"], + ) -> str: + for bs_type in bs_rank: + if bs_data[bs_type] is not None: + yield bs_type + + blessed_bs_key = next(_bs_eval(bs_data)) + + bs_entry = BandstructureData(**bs_data) # type: ignore + band_gap = getattr(bs_entry, blessed_bs_key).band_gap + cbm = (getattr(bs_entry, blessed_bs_key).cbm or {}).get("energy", None) # type: ignore + vbm = (getattr(bs_entry, blessed_bs_key).vbm or {}).get("energy", None) # type: ignore + efermi = getattr(bs_entry, blessed_bs_key).efermi # type: ignore + is_gap_direct = getattr(bs_entry, blessed_bs_key).is_gap_direct # type: ignore + is_metal = getattr(bs_entry, blessed_bs_key).is_metal # type: ignore + + es_origins_from_bs = None + for origin in origins: + if origin.name == blessed_bs_key: + es_origins_from_bs = PropertyOrigin( + name="electronic_structure", + last_updated=origin.last_updated, + task_id=origin.task_id, + ) - dos_mag_ordering = CollinearMagneticStructureAnalyzer(structure).ordering + bs_magnetic_ordering = CollinearMagneticStructureAnalyzer( + structures[es_origins_from_bs.task_id], + round_magmoms=True, + threshold_nonmag=0.2, + threshold=0, + ).ordering + + return { + "band_gap": band_gap, + "cbm": cbm, + "vbm": vbm, + "efermi": efermi, + "is_gap_direct": is_gap_direct, + "is_metal": is_metal, + "bs_magnetic_ordering": bs_magnetic_ordering, + "bandstructure": bs_entry, + "es_origins_from_bs": es_origins_from_bs, + } + + +def _generate_dos_data( + dos: DosShim, + origins: list[PropertyOrigin], + structures: 
dict[IdentifierType, Structure], +) -> dict: + dos_task, dos_obj, _ = dos.dos + + orbitals = [OrbitalType.s, OrbitalType.p, OrbitalType.d] + spins = list(dos_obj.densities.keys()) + + ele_dos = dos_obj.get_element_dos() + tot_orb_dos = dos_obj.get_spd_dos() + + elements = ele_dos.keys() + + dos_efermi = dos_obj.efermi + structure = structures[dos_task] + + dos_magnetic_ordering = CollinearMagneticStructureAnalyzer(structure).ordering + + dos_data = { + "task_id": dos_task, + "total": defaultdict(dict), + "elemental": {element: defaultdict(dict) for element in elements}, + "orbital": defaultdict(dict), + "magnetic_ordering": dos_magnetic_ordering, + } + + for spin in spins: + # - Process total DOS data + band_gap = dos_obj.get_gap(spin=spin) + (cbm, vbm) = dos_obj.get_cbm_vbm(spin=spin) + + try: + spin_polarization = dos_obj.spin_polarization + if spin_polarization is None or isnan(spin_polarization): + spin_polarization = None + except KeyError: + spin_polarization = None + + dos_data["total"][spin] = DosSummaryData( # type: ignore[index] + band_gap=band_gap, + cbm=cbm, + vbm=vbm, + efermi=dos_efermi, + spin_polarization=spin_polarization, + ) - dos_data = { - "total": defaultdict(dict), - "elemental": {element: defaultdict(dict) for element in elements}, - "orbital": defaultdict(dict), - "magnetic_ordering": dos_mag_ordering, - } + # - Process total orbital projection data + for orbital in orbitals: + band_gap = tot_orb_dos[orbital].get_gap(spin=spin) - for spin in spins: - # - Process total DOS data - band_gap = dos_obj.get_gap(spin=spin) - (cbm, vbm) = dos_obj.get_cbm_vbm(spin=spin) + (cbm, vbm) = tot_orb_dos[orbital].get_cbm_vbm(spin=spin) - try: - spin_polarization = dos_obj.spin_polarization - if spin_polarization is None or isnan(spin_polarization): - spin_polarization = None - except KeyError: - spin_polarization = None + spin_polarization = None - dos_data["total"][spin] = DosSummaryData( # type: ignore[index] - task_id=dos_task, + 
dos_data["orbital"][str(orbital)][spin] = DosSummaryData( # type: ignore[index] band_gap=band_gap, cbm=cbm, vbm=vbm, @@ -354,16 +555,25 @@ def from_bsdos( # type: ignore[override] spin_polarization=spin_polarization, ) - # - Process total orbital projection data - for orbital in orbitals: - band_gap = tot_orb_dos[orbital].get_gap(spin=spin) + # - Process element and element orbital projection data + for ele in ele_dos: + orb_dos = dos_obj.get_element_spd_dos(ele) + + for orbital in ["total"] + list(orb_dos.keys()): # type: ignore[assignment] + if orbital == "total": + proj_dos = ele_dos + label = ele + else: + proj_dos = orb_dos + label = orbital - (cbm, vbm) = tot_orb_dos[orbital].get_cbm_vbm(spin=spin) + for spin in spins: + band_gap = proj_dos[label].get_gap(spin=spin) + (cbm, vbm) = proj_dos[label].get_cbm_vbm(spin=spin) spin_polarization = None - dos_data["orbital"][str(orbital)][spin] = DosSummaryData( # type: ignore[index] - task_id=dos_task, + dos_data["elemental"][ele][str(orbital)][spin] = DosSummaryData( # type: ignore[index] band_gap=band_gap, cbm=cbm, vbm=vbm, @@ -371,207 +581,166 @@ def from_bsdos( # type: ignore[override] spin_polarization=spin_polarization, ) - # - Process element and element orbital projection data - for ele in ele_dos: - orb_dos = dos_obj.get_element_spd_dos(ele) - - for orbital in ["total"] + list(orb_dos.keys()): # type: ignore[assignment] - if orbital == "total": - proj_dos = ele_dos - label = ele - else: - proj_dos = orb_dos - label = orbital - - for spin in spins: - band_gap = proj_dos[label].get_gap(spin=spin) - (cbm, vbm) = proj_dos[label].get_cbm_vbm(spin=spin) - - spin_polarization = None - - dos_data["elemental"][ele][str(orbital)][spin] = DosSummaryData( # type: ignore[index] - task_id=dos_task, - band_gap=band_gap, - cbm=cbm, - vbm=vbm, - efermi=dos_efermi, - spin_polarization=spin_polarization, - ) - - # -- Process band structure data - bs_data = { # type: ignore - "setyawan_curtarolo": setyawan_curtarolo, - "hinuma": 
hinuma, - "latimer_munro": latimer_munro, - } + dos_entry = DosData(**dos_data) # type: ignore[arg-type] - for bs_type, bs_input in bs_data.items(): - if bs_input is not None: - bs_task, bs = list(bs_input.items())[0] - - if structures is not None and structures[bs_task]: - bs_mag_ordering = CollinearMagneticStructureAnalyzer( - structures[bs_task] - ).ordering - else: - bs_mag_ordering = CollinearMagneticStructureAnalyzer( - bs.structure # type: ignore[arg-type] - ).ordering - - gap_dict = bs.get_band_gap() - is_metal = bs.is_metal() - direct_gap = bs.get_direct_band_gap() - - if is_metal: - band_gap = 0.0 - cbm = None # type: ignore[assignment] - vbm = None # type: ignore[assignment] - is_gap_direct = False - else: - band_gap = gap_dict["energy"] - cbm = bs.get_cbm() # type: ignore[assignment] - vbm = bs.get_vbm() # type: ignore[assignment] - is_gap_direct = gap_dict["direct"] - - # coerce type here, mixture of str and int types in bs objects - cbm["kpoint_index"] = [int(x) for x in cbm["kpoint_index"]] # type: ignore[index] - vbm["kpoint_index"] = [int(x) for x in vbm["kpoint_index"]] # type: ignore[index] - - bs_efermi = bs.efermi - nbands = bs.nb_bands - - # - Get equivalent labels between different conventions - hskp = HighSymmKpath( - bs.structure, - path_type="all", - symprec=0.1, - angle_tolerance=5, - atol=1e-5, - ) - equivalent_labels = hskp.equiv_labels - - if bs_type == "latimer_munro": - gen_labels = set( - [ - label - for label in equivalent_labels["latimer_munro"][ - "setyawan_curtarolo" - ] - ] - ) - kpath_labels = set( - [ - kpoint.label - for kpoint in bs.kpoints - if kpoint.label is not None - ] - ) - - if not gen_labels.issubset(kpath_labels): - new_structure = SpacegroupAnalyzer( - bs.structure # type: ignore[arg-type] - ).get_primitive_standard_structure( - international_monoclinic=False - ) - - hskp = HighSymmKpath( - new_structure, - path_type="all", - symprec=SETTINGS.SYMPREC, - angle_tolerance=SETTINGS.ANGLE_TOL, - atol=1e-5, - ) - 
equivalent_labels = hskp.equiv_labels - - bs_data[bs_type] = BandStructureSummaryData( # type: ignore - task_id=bs_task, - band_gap=band_gap, - direct_gap=direct_gap, - cbm=cbm, - vbm=vbm, - is_gap_direct=is_gap_direct, - is_metal=is_metal, - efermi=bs_efermi, - nbands=nbands, - equivalent_labels=equivalent_labels, - magnetic_ordering=bs_mag_ordering, - ) + dos_cbm, dos_vbm = dos_obj.get_cbm_vbm() + dos_gap = max(dos_cbm - dos_vbm, 0.0) - bs_entry = BandstructureData(**bs_data) # type: ignore - dos_entry = DosData(**dos_data) # type: ignore[arg-type] + is_metal = True if np.isclose(dos_gap, 0.0, atol=0.01, rtol=0) else False - # Obtain summary data + es_origins_from_dos = None + for origin in origins: + if origin.task_id == dos_task: + es_origins_from_dos = PropertyOrigin( + name="electronic_structure", + last_updated=origin.last_updated, + task_id=dos_task, + ) - bs_gap = ( - bs_entry.setyawan_curtarolo.band_gap - if bs_entry.setyawan_curtarolo is not None - else None + return { + "band_gap": dos_gap, + "cbm": dos_cbm, + "vbm": dos_vbm, + "efermi": dos_efermi, + "is_metal": is_metal, + "dos_magnetic_ordering": dos_magnetic_ordering, + "es_origins_from_dos": es_origins_from_dos, + "dos_entry": dos_entry, + } + + +def bs_checks( + doc: ElectronicStructureDoc, + structures: dict[str, Structure], + bandstructures: BSShim, + skip_primitive_check: bool = False, +) -> ElectronicStructureDoc: + for _, bs_summary in doc.bandstructure: + if bs_summary is not None: + _bandgap_diff_check(doc, bs_summary.band_gap, bs_summary.task_id) + + mag_orderings: list[tuple[str, Ordering]] = [ + (bs_summary.task_id, bs_summary.magnetic_ordering) + for _, bs_summary in doc.bandstructure + if bs_summary is not None + ] + + _magnetic_ordering_check(doc, mag_orderings) + + if not skip_primitive_check: + _structure_primitive_checks(doc, structures) + + for _, bandstructure in bandstructures: + if bandstructure is not None: + task_id, _, lmaxmix = bandstructure + _lmaxmix_check(doc, 
structures[task_id], lmaxmix, task_id) + + return doc + + +def dos_checks( + doc: ElectronicStructureDoc, + structures: dict[str, Structure], + dos: DosShim, + skip_primitive_check: bool = False, +) -> ElectronicStructureDoc: + _bandgap_diff_check( + doc, + doc.dos.total[Spin.up].band_gap, + doc.dos.task_id, + ) + + mag_orderings: list[tuple[str, Ordering]] = [ + ( + doc.dos.task_id, + doc.dos.magnetic_ordering, ) - dos_cbm, dos_vbm = dos_obj.get_cbm_vbm() - dos_gap = max(dos_cbm - dos_vbm, 0.0) - - new_origin_last_updated = None - new_origin_task_id = None - - if bs_gap is not None and bs_gap <= dos_gap + 0.2: - summary_task = bs_entry.setyawan_curtarolo.task_id # type: ignore - summary_band_gap = bs_gap - summary_cbm = ( - bs_entry.setyawan_curtarolo.cbm.get("energy", None) # type: ignore - if bs_entry.setyawan_curtarolo.cbm is not None # type: ignore - else None + ] + + _magnetic_ordering_check(doc, mag_orderings) + + if not skip_primitive_check: + _structure_primitive_checks(doc, structures) + + task_id, dos_obj, lmaxmix = dos.dos + _lmaxmix_check(doc, structures[task_id], lmaxmix, doc.dos.task_id) + + return doc + + +def bsdos_checks( + doc: ElectronicStructureDoc, + structures: dict[str, Structure], + bandstructures: BSShim, + dos: DosShim, +) -> ElectronicStructureDoc: + _structure_primitive_checks(doc, structures) + return dos_checks( + bs_checks( + doc, + structures, + bandstructures, + skip_primitive_check=True, + ), + structures, + dos, + skip_primitive_check=True, + ) + + +def _bandgap_diff_check( + doc: ElectronicStructureDoc, band_gap: float, task_id: IdentifierType +) -> None: + if abs(doc.band_gap - band_gap) > 0.25: + doc.warnings.append( + "Absolute difference between blessed band gap and the band gap for" + f"task {str(task_id)} is larger than 0.25 eV.", + ) + + +def _magnetic_ordering_check( + doc: ElectronicStructureDoc, mag_orderings: list[tuple[IdentifierType, Ordering]] +) -> None: + for task_id, ordering in mag_orderings: + if 
doc.magnetic_ordering != ordering: + doc.warnings.append( + f"Summary data magnetic ordering does not agree with the ordering from {str(task_id)}" ) - summary_vbm = ( - bs_entry.setyawan_curtarolo.vbm.get("energy", None) # type: ignore - if bs_entry.setyawan_curtarolo.cbm is not None # type: ignore - else None - ) # type: ignore - summary_efermi = bs_entry.setyawan_curtarolo.efermi # type: ignore - is_gap_direct = bs_entry.setyawan_curtarolo.is_gap_direct # type: ignore - is_metal = bs_entry.setyawan_curtarolo.is_metal # type: ignore - - for origin in origins: - if origin["name"] == "setyawan_curtarolo": - new_origin_last_updated = origin["last_updated"] - new_origin_task_id = origin["task_id"] - - else: - summary_task = dos_entry.model_dump()["total"][str(Spin.up)]["task_id"] - summary_band_gap = dos_gap - summary_cbm = dos_cbm - summary_vbm = dos_vbm - summary_efermi = dos_efermi - is_metal = True if np.isclose(dos_gap, 0.0, atol=0.01, rtol=0) else False - - for origin in origins: - if origin["name"] == "dos": - new_origin_last_updated = origin["last_updated"] - new_origin_task_id = origin["task_id"] - - if new_origin_task_id is not None: - for origin in origins: - if origin["name"] == "electronic_structure": - origin["last_updated"] = new_origin_last_updated - origin["task_id"] = new_origin_task_id - - summary_magnetic_ordering = CollinearMagneticStructureAnalyzer( - kwargs["meta_structure"], - round_magmoms=True, - threshold_nonmag=0.2, - threshold=0, - ).ordering - - return cls.from_structure( - material_id=AlphaID(material_id), - task_id=summary_task, - band_gap=summary_band_gap, - cbm=summary_cbm, - vbm=summary_vbm, - efermi=summary_efermi, - is_gap_direct=is_gap_direct, - is_metal=is_metal, - magnetic_ordering=summary_magnetic_ordering, - bandstructure=bs_entry, - dos=dos_entry, - **kwargs, + + +def _lmaxmix_check( + doc: ElectronicStructureDoc, + structure: Structure, + lmaxmix: int, + task_id: IdentifierType, +) -> None: + # VASP default LMAXMIX is 2 + 
expected_lmaxmix = MPStaticSet(structure).incar.get("LMAXMIX", 2) + if lmaxmix != expected_lmaxmix: + doc.warnings.append( + "An incorrect calculation parameter may lead to errors in the band gap of " + f"0.1-0.2 eV (LMAXMIX is {lmaxmix} and should be {expected_lmaxmix} for " + f"{str(task_id)})." ) + + +def _structure_primitive_checks( + doc: ElectronicStructureDoc, structures: dict[IdentifierType, Structure] +) -> None: + for task_id, struct in structures.items(): + struct_prim = SpacegroupAnalyzer(struct).get_primitive_standard_structure( + international_monoclinic=False + ) + + if not np.allclose( + struct.lattice.matrix, struct_prim.lattice.matrix, atol=1e-3 + ): + if np.isclose(struct_prim.volume, struct.volume, atol=5, rtol=0): + doc.warnings.append( + f"The input structure for {str(task_id)} is primitive but may not exactly match the " + f"standard primitive setting." + ) + else: + doc.warnings.append( + f"The input structure for {str(task_id)} does not match the expected standard primitive" + ) diff --git a/emmet-core/emmet/core/material.py b/emmet-core/emmet/core/material.py index 5e1bbf2c05..5d79d753c0 100644 --- a/emmet-core/emmet/core/material.py +++ b/emmet-core/emmet/core/material.py @@ -16,7 +16,7 @@ MoleculeType, StructureType, ) -from emmet.core.types.typing import DateTimeType, IdentifierType +from emmet.core.types.typing import DateTimeType, IdentifierType, MaterialIdentifierType if TYPE_CHECKING: @@ -46,7 +46,7 @@ class BasePropertyMetadata(StructureMetadata, EmmetBaseModel): extended, not used directly """ - material_id: IdentifierType | None = Field( + material_id: MaterialIdentifierType | None = Field( None, description="The Materials Project ID of the material, used as a universal reference across property documents." 
"This comes in the form: mp-******.", diff --git a/emmet-core/emmet/core/oxidation_states.py b/emmet-core/emmet/core/oxidation_states.py index fcecdb22e2..9ff8586305 100644 --- a/emmet-core/emmet/core/oxidation_states.py +++ b/emmet-core/emmet/core/oxidation_states.py @@ -2,12 +2,12 @@ import logging from collections import defaultdict +from copy import deepcopy from typing import TYPE_CHECKING import numpy as np from pydantic import Field from pymatgen.analysis.bond_valence import BVAnalyzer -from pymatgen.core import Structure from pymatgen.core.periodic_table import Specie from emmet.core.material_property import PropertyDoc @@ -49,9 +49,11 @@ class OxidationStateDoc(PropertyDoc): @classmethod def from_structure( - cls, structure: Structure, material_id: IdentifierType | None = None, **kwargs + cls, + structure: StructureType, + material_id: IdentifierType | None = None, + **kwargs, ): - # Check if structure already has oxidation states, # if so pass this along unchanged with "method" == "manualx" struct_valences: list[float | None] = [] @@ -101,11 +103,20 @@ def from_structure( } if d["method"] == OxiStateAssigner.BVA: + # BVAnalyzer (through SpaceGroupAnalyzer) is sensitive to magnetic configuration + # -> magmoms can be removed for improved reliablity during oxi state analysis, + # but original structure should be passed as meta_structure to preserve data + if "magmom" in structure.site_properties: + meta_structure = deepcopy(structure) + structure.remove_site_property("magmom") + else: + meta_structure = structure + try: bva = BVAnalyzer() valences = bva.get_valences(structure) possible_species = { - str(Specie(structure[idx].specie, oxidation_state=valence)) + str(Specie(str(structure[idx].specie), oxidation_state=valence)) for idx, valence in enumerate(valences) } @@ -166,7 +177,7 @@ def from_structure( d["method"] = None return super().from_structure( - meta_structure=structure, + meta_structure=meta_structure, **d, **kwargs, ) diff --git 
a/emmet-core/emmet/core/provenance.py b/emmet-core/emmet/core/provenance.py index 7bf32c4b81..f3c61d78fa 100644 --- a/emmet-core/emmet/core/provenance.py +++ b/emmet-core/emmet/core/provenance.py @@ -4,19 +4,20 @@ import json import warnings -from typing import TYPE_CHECKING, Any, Annotated +from typing import TYPE_CHECKING, Annotated, Any, Literal from pybtex.database import BibliographyData, parse_string from pybtex.errors import set_strict_mode -from pydantic import BaseModel, Field, BeforeValidator -from pymatgen.core import Lattice, Structure, PeriodicSite +from pydantic import BaseModel, BeforeValidator, Field +from pymatgen.core import Lattice, PeriodicSite, Structure from emmet.core.material_property import PropertyDoc from emmet.core.math import Matrix3D +from emmet.core.structure import StructureMetadata from emmet.core.symmetry import SymmetryData from emmet.core.types.enums import ValueEnum -from emmet.core.types.pymatgen_types.structure_adapter import StructureType from emmet.core.types.pymatgen_types.lattice_adapter import LatticeType +from emmet.core.types.pymatgen_types.structure_adapter import StructureType from emmet.core.types.typing import DateTimeType, IdentifierType if TYPE_CHECKING: @@ -282,6 +283,10 @@ class SNLAbout(BaseModel): def migrate_legacy_data(cls, config: dict[str, Any]) -> Self: """Migrate legacy SNL data with free-form JSON values to schematized.""" config["history"] = _migrate_legacy_history_data(config.get("history", [])) + if projs := config.pop("projects", None): + if "tags" not in config: + config["tags"] = [] + config["tags"].extend(projs) return cls(**config) @@ -419,3 +424,78 @@ def migrate_legacy_data(cls, config: dict[str, Any]) -> Self: """Migrate legacy provenance data with free-form JSON values to schematized.""" config["history"] = _migrate_legacy_history_data(config.get("history", [])) return cls(**config) + + +class DatabaseSNL(StructureMetadata): + """Define schemas for database entries. 
+ + This particular SNL schema is used for + experimental databases like ICSD and Pauling File. + """ + + snl_id: str | None = Field(None, description="The SNL ID for this entry") + structure: StructureType | None = Field( + None, description="The structure for this entry" + ) + about: SNLAbout | None = Field( + None, description="Extended metadata for this entry." + ) + theoretical: bool = Field( + True, description="Whether this entry is a theoretical database entry." + ) + is_ordered: bool | None = Field( + None, + description="Whether this represents a (configurationally) ordered structure.", + ) + last_updated: DateTimeType = Field( + description="The last time this entry was updated." + ) + tags: list[str] | None = Field( + None, description="List of high-level metadata for this entry." + ) + source: Literal["icsd", "pauling", "mp-complete", "user"] | None = Field( + None, description="The source of this SNL." + ) + submission_id: int | None = Field( + None, description="If applicable, the identifier of the submitted structure." + ) + submitter_email: str | None = Field( + None, + description="If applicable, the email of the user who submitted the structure.", + ) + + @classmethod + def migrate_legacy_config(cls, config: dict) -> Self: + """Migrate legacy, JSONL-format SNLs to the current schema. + + Legacy database SNLs appear to extend the properties of the + pymatgen Structure object. 
+ """ + if all( + config.get(k) + for k in ( + "sites", + "lattice", + ) + ): + config["structure"] = Structure.from_dict( + { + k: config.pop(k, None) + for k in ( + "sites", + "lattice", + ) + } + ) + if "structure" in config and "is_ordered" not in config: + config["is_ordered"] = config["structure"].is_ordered + + if (expt := config.pop("experimental", None)) is not None: + config["theoretical"] = not expt + + if "structure" in config: + return cls.from_structure( + meta_structure=config["structure"], + **config, + ) + return cls(**config) diff --git a/emmet-core/emmet/core/similarity.py b/emmet-core/emmet/core/similarity.py index cefb84d2f3..8d814133d4 100644 --- a/emmet-core/emmet/core/similarity.py +++ b/emmet-core/emmet/core/similarity.py @@ -189,8 +189,8 @@ def _featurize_structure(self, structure: Structure) -> np.ndarray: """ raise NotImplementedError + @staticmethod def _post_process_distance( - self, distances: np.ndarray, ) -> np.ndarray: """Postprocess vector distances to yield consistent similarity scores. @@ -242,8 +242,9 @@ def featurize_structures( return np.array(_feature_vectors) + @classmethod def _get_closest_vectors( - self, idx: int, v: np.ndarray, num: int + cls, idx: int, v: np.ndarray, num: int ) -> tuple[np.ndarray, np.ndarray]: """Return only a subset of vectors most similar to a specified vector. @@ -260,7 +261,7 @@ def _get_closest_vectors( subset_dist = dist[idxs] sorted_subset_idx = np.argsort(subset_dist) - return idxs[sorted_subset_idx], self._post_process_distance( + return idxs[sorted_subset_idx], cls._post_process_distance( subset_dist[sorted_subset_idx] ) @@ -472,7 +473,8 @@ def _featurize_structure(self, structure: Structure) -> np.ndarray: except Exception: return np.nan * np.ones(self.num_feature) - def _post_process_distance(self, distances: np.ndarray) -> np.ndarray: + @staticmethod + def _post_process_distance(distances: np.ndarray) -> np.ndarray: """Use exponential weighting of feature vector distances. 
Parameters @@ -501,7 +503,35 @@ def __init__(self, model: str | Path = "M3GNet-MP-2018.6.1-Eform"): if matgl is None: raise ValueError("`pip install matgl` to use these features.") - self.model = matgl.load_model(Path(model)).model + self._model_path = Path(model) + self._model = None + + @staticmethod + def _load_model(model_path: str | Path): + return matgl.load_model(Path(model_path)).model + + @property + def model(self): + """Return the matgl model.""" + if self._model_path and self._model is None: + self._model = self._load_model(self._model_path) + return self._model + + @staticmethod + def _model_readout_from_structure( + structure: Structure, + model, + ) -> np.ndarray: + """Featurize a structure using a given matgl model.""" + try: + return ( + model.predict_structure(structure, return_features=True)["readout"] + .detach() + .cpu() + .numpy()[0] + ) + except Exception: + return np.nan * np.ones(128) def _featurize_structure(self, structure: Structure) -> np.ndarray: """Featurize a single structure using M3GNet-Eform. @@ -514,8 +544,68 @@ def _featurize_structure(self, structure: Structure) -> np.ndarray: ----------- np.ndarray """ - results = self.model.predict_structure(structure, return_features=True) - return results["readout"].detach().cpu().numpy()[0] + return self._model_readout_from_structure(structure, self.model) + + def _featurize_structures( + self, + structures: dict[int, Structure], + model_path: Path, + fvd: dict[int, np.ndarray], + ): + model = self._load_model(model_path) + new_fvs = { + idx: self._model_readout_from_structure(structure, model) + for idx, structure in structures.items() + } + fvd.update(new_fvs) + + def featurize_structures( + self, + structures: list[Structure], + num_procs: int = 1, + ): + """Featurize structures using the user-defined _featurize_structure. + + Rewritten because of pickling issues with torch objects. 
+ + Parameters + ----------- + structures : list of Structure objects + num_procs : int = 1 + Number of parallel processes to run in featurizing structures. + + Returns + ----------- + np.ndarray : the feature vectors of the input structures. + """ + if num_procs > 1: + manager = multiprocessing.Manager() + fvd = manager.dict() + procs = [] + num_batches = int(np.ceil(len(structures) / num_procs)) + for i in range(num_procs): + idxs = (i * num_batches, min((i + 1) * num_batches, len(structures))) + proc = multiprocessing.Process( + target=self._featurize_structures, + args=( + {idx: structures[idx] for idx in range(*idxs)}, + self._model_path, + fvd, + ), + ) + proc.start() + procs.append(proc) + + for proc in procs: + proc.join() + + _feature_vectors = [fvd[idx] for idx in range(len(structures))] + else: + _feature_vectors = [ + self._featurize_structure(structure) for structure in structures + ] + + return np.array(_feature_vectors) class SimilarityEntry(BaseModel): diff --git a/emmet-core/emmet/core/summary.py b/emmet-core/emmet/core/summary.py index 6e414bd177..520fa3b9f4 100644 --- a/emmet-core/emmet/core/summary.py +++ b/emmet-core/emmet/core/summary.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections import ChainMap +from itertools import chain from typing import TYPE_CHECKING -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from emmet.core.electronic_structure import BandstructureData, DosData from emmet.core.material_property import PropertyDoc @@ -18,16 +20,14 @@ class HasProps(ValueEnum): """ - Enum of possible hasprops values. + Enum of possible has_props values. 
""" materials = "materials" thermo = "thermo" xas = "xas" grain_boundaries = "grain_boundaries" - chemenv = "chemenv" electronic_structure = "electronic_structure" - absorption = "absorption" bandstructure = "bandstructure" dos = "dos" magnetism = "magnetism" @@ -40,8 +40,10 @@ class HasProps(ValueEnum): charge_density = "charge_density" eos = "eos" phonon = "phonon" + absorption = "absorption" insertion_electrodes = "insertion_electrodes" substrates = "substrates" + chemenv = "chemenv" class SummaryStats(BaseModel): @@ -89,486 +91,341 @@ class SummaryStats(BaseModel): ) -class XASSearchData(BaseModel): - """ - Fields in XAS sub docs in summary - """ +class PropModel(BaseModel): + """Check for model initialization outside of defaults.""" - edge: XasEdge | None = Field( - None, - title="Absorption Edge", - description="The interaction edge for XAS", - ) - absorbing_element: ElementType | None = Field( - None, - description="Absorbing element.", - ) + _prop: str = PrivateAttr("material") - spectrum_type: XasType | None = Field( - None, - description="Type of XAS spectrum.", - ) + @property + def _has_props(self) -> bool: + return not not self.model_fields_set + @property + def name(self) -> HasProps: + return HasProps[self._prop] -class GBSearchData(BaseModel): - """ - Fields in grain boundary sub docs in summary - """ - sigma: int | None = Field( - None, - description="Sigma value of the boundary.", - ) - - type: str | None = Field( - None, - description="Grain boundary type.", - ) - - gb_energy: float | None = Field( - None, - description="Grain boundary energy in J/m^2.", - ) - - rotation_angle: float | None = Field( - None, - description="Rotation angle in degrees.", - ) - - -class SummaryDoc(PropertyDoc): - """ - Summary information about materials and their properties, useful for materials - screening studies and searching. 
- """ - - property_name: str = "summary" - - # Materials +class MaterialsSummary(PropertyDoc, PropModel): task_ids: list[IdentifierType] = Field( [], title="Calculation IDs", description="List of Calculations IDs associated with this material.", ) - structure: StructureType = Field( - ..., - description="The lowest energy structure for this material.", + ..., description="The lowest energy structure for this material." ) - # Thermo +class ThermoSummary(PropModel): + _prop: str = PrivateAttr("thermo") uncorrected_energy_per_atom: float | None = Field( None, description="The total DFT energy of this material per atom in eV/atom.", ) - energy_per_atom: float | None = Field( None, description="The total corrected DFT energy of this material per atom in eV/atom.", ) - formation_energy_per_atom: float | None = Field( None, description="The formation energy per atom in eV/atom.", ) - energy_above_hull: float | None = Field( None, description="The energy above the hull in eV/Atom.", ) - is_stable: bool = Field( False, description="Flag for whether this material is on the hull and therefore stable.", ) - equilibrium_reaction_energy_per_atom: float | None = Field( None, description="The reaction energy of a stable entry from the neighboring equilibrium stable materials in eV." " Also known as the inverse distance to hull.", ) - decomposes_to: list[DecompositionProduct] | None = Field( None, description="List of decomposition data for this material. Only valid for metastable or unstable material.", ) - # XAS - xas: list[XASSearchData] | None = Field( - None, - description="List of xas documents.", +class XASSearchData(PropModel): + """ + Fields in XAS sub docs in summary + """ + + edge: XasEdge | None = Field( + None, title="Absorption Edge", description="The interaction edge for XAS" + ) + absorbing_element: ElementType | None = Field( + None, description="Absorbing element." 
) + spectrum_type: XasType | None = Field(None, description="Type of XAS spectrum.") -# GB - grain_boundaries: list[GBSearchData] | None = Field( +class XASSummary(PropModel): + _prop: str = PrivateAttr("xas") + xas: list[XASSearchData] | None = Field( None, - description="List of grain boundary documents.", + description="List of xas documents.", ) -# Electronic Structure - band_gap: float | None = Field( - None, - description="Band gap energy in eV.", - ) +class GBSearchData(BaseModel): + """ + Fields in grain boundary sub docs in summary + """ - cbm: float | None = Field( - None, - description="Conduction band minimum data.", - ) + sigma: int | None = Field(None, description="Sigma value of the boundary.") + type: str | None = Field(None, description="Grain boundary type.") + gb_energy: float | None = Field(None, description="Grain boundary energy in J/m^2.") + rotation_angle: float | None = Field(None, description="Rotation angle in degrees.") - vbm: float | None = Field( - None, - description="Valence band maximum data.", - ) - efermi: float | None = Field( - None, - description="Fermi energy in eV.", +class GBSummary(PropModel): + _prop: str = PrivateAttr("grain_boundaries") + grain_boundaries: list[GBSearchData] | None = Field( + None, description="List of grain boundary documents." ) - is_gap_direct: bool | None = Field( - None, - description="Whether the band gap is direct.", - ) - is_metal: bool | None = Field( - None, - description="Whether the material is a metal.", +class ElectronicStructureSummary(PropModel): + _prop: str = PrivateAttr("electronic_structure") + band_gap: float | None = Field(None, description="Band gap energy in eV.") + cbm: float | None = Field(None, description="Conduction band minimum data.") + vbm: float | None = Field(None, description="Valence band maximum data.") + efermi: float | None = Field(None, description="Fermi energy in eV.") + is_gap_direct: bool | None = Field( + None, description="Whether the band gap is direct." 
) + is_metal: bool | None = Field(None, description="Whether the material is a metal.") - es_source_calc_id: IdentifierType | None = Field( - None, - description="The source calculation ID for the electronic structure data.", - ) +class BandstructureSummary(PropModel): + _prop: str = PrivateAttr("bandstructure") bandstructure: BandstructureData | None = Field( - None, - description="Band structure data for the material.", - ) - - dos: DosData | None = Field( - None, - description="Density of states data for the material.", + None, description="Band structure data for the material." ) - # DOS - dos_energy_up: float | None = Field( - None, - description="Spin-up DOS band gap in eV.", +class DosSummary(PropModel): + _prop: str = PrivateAttr("dos") + dos: DosData | None = Field( + None, description="Density of states data for the material." ) - + dos_energy_up: float | None = Field(None, description="Spin-up DOS band gap in eV.") dos_energy_down: float | None = Field( - None, - description="Spin-down DOS band gap in eV.", + None, description="Spin-down DOS band gap in eV." ) - # Magnetism +class MagnetismSummary(PropModel): + _prop: str = PrivateAttr("magnetism") is_magnetic: bool | None = Field( - None, - description="Whether the material is magnetic.", + None, description="Whether the material is magnetic." ) - - ordering: str | None = Field( - None, - description="Type of magnetic ordering.", - ) - + ordering: str | None = Field(None, description="Type of magnetic ordering.") total_magnetization: float | None = Field( - None, - description="Total magnetization in μB.", + None, description="Total magnetization in μB." ) - total_magnetization_normalized_vol: float | None = Field( - None, - description="Total magnetization normalized by volume in μB/ų.", + None, description="Total magnetization normalized by volume in μB/ų." 
) - total_magnetization_normalized_formula_units: float | None = Field( - None, - description="Total magnetization normalized by formula unit in μB/f.u. .", + None, description="Total magnetization normalized by formula unit in μB/f.u. ." ) - num_magnetic_sites: int | None = Field( - None, - description="The number of magnetic sites.", + None, description="The number of magnetic sites." ) - num_unique_magnetic_sites: int | None = Field( - None, - description="The number of unique magnetic sites.", + None, description="The number of unique magnetic sites." ) - types_of_magnetic_species: list[ElementType] | None = Field( - None, - description="Magnetic specie elements.", + None, description="Magnetic specie elements." ) - # Elasticity +class ElasticitySummary(PropModel): + _prop: str = PrivateAttr("elasticity") # k_voigt: float | None = Field(None, description="Voigt average of the bulk modulus.") - # k_reuss: float | None = Field(None, description="Reuss average of the bulk modulus in GPa.") - # k_vrh: float | None = Field(None, description="Voigt-Reuss-Hill average of the bulk modulus in GPa.") - # g_voigt: float | None = Field(None, description="Voigt average of the shear modulus in GPa.") - # g_reuss: float | None = Field(None, description="Reuss average of the shear modulus in GPa.") - # g_vrh: float | None = Field(None, description="Voigt-Reuss-Hill average of the shear modulus in GPa.") - bulk_modulus: dict[str, float] | None = Field( None, description="Voigt, Reuss, and Voigt-Reuss-Hill averages of the bulk modulus in GPa.", ) - shear_modulus: dict[str, float] | None = Field( None, description="Voigt, Reuss, and Voigt-Reuss-Hill averages of the shear modulus in GPa.", ) - universal_anisotropy: float | None = Field(None, description="Elastic anisotropy.") - homogeneous_poisson: float | None = Field(None, description="Poisson's ratio.") - # Dielectric and Piezo - - e_total: float | None = Field( - None, - description="Total dielectric constant.", - ) +class 
DielectricSummary(PropModel): + _prop: str = PrivateAttr("dielectric") + e_total: float | None = Field(None, description="Total dielectric constant.") e_ionic: float | None = Field( - None, - description="Ionic contribution to dielectric constant.", + None, description="Ionic contribution to dielectric constant." ) - e_electronic: float | None = Field( - None, - description="Electronic contribution to dielectric constant.", + None, description="Electronic contribution to dielectric constant." ) + n: float | None = Field(None, description="Refractive index.") - n: float | None = Field( - None, - description="Refractive index.", - ) - e_ij_max: float | None = Field( - None, - description="Piezoelectric modulus.", - ) +class PiezoelectricSummary(PropModel): + _prop: str = PrivateAttr("piezoelectric") + e_ij_max: float | None = Field(None, description="Piezoelectric modulus.") - # Surface Properties +class SurfacesSummary(PropModel): + _prop: str = PrivateAttr("surfaces") weighted_surface_energy_EV_PER_ANG2: float | None = Field( - None, - description="Weighted surface energy in eV/Ų.", + None, description="Weighted surface energy in eV/Ų." ) - weighted_surface_energy: float | None = Field( - None, - description="Weighted surface energy in J/m².", + None, description="Weighted surface energy in J/m²." ) - weighted_work_function: float | None = Field( - None, - description="Weighted work function in eV.", + None, description="Weighted work function in eV." ) - surface_anisotropy: float | None = Field( - None, - description="Surface energy anisotropy.", + None, description="Surface energy anisotropy." ) - - shape_factor: float | None = Field( - None, - description="Shape factor.", - ) - + shape_factor: float | None = Field(None, description="Shape factor.") has_reconstructed: bool | None = Field( - None, - description="Whether the material has any reconstructed surfaces.", + None, description="Whether the material has any reconstructed surfaces." 
) - # Oxi States +class OxiStatesSummary(PropModel): + _prop: str = PrivateAttr("oxi_states") possible_species: list[str] | None = Field( - None, - description="Possible charged species in this material.", + None, description="Possible charged species in this material." ) - # Has Props - has_props: dict[str, bool] | None = Field( - None, - description="List of properties that are available for a given material.", +class ProvenanceSummary(PropModel): + _prop: str = PrivateAttr("provenance") + theoretical: bool = Field(True, description="Whether the material is theoretical.") + database_IDs: dict[str, list[str]] | None = Field( + None, description="External database IDs corresponding to this material." ) - # Theoretical - theoretical: bool = Field( - True, - description="Whether the material is theoretical.", - ) +# ----------------------------------------------------------------------------- +# Shims for populating has_props for properties that do not add +# values to SummaryDoc +# ----------------------------------------------------------------------------- +class ChargeDensityData(PropModel): + _prop: str = PrivateAttr("charge_density") + exists: bool = False + + +class EosData(PropModel): + _prop: str = PrivateAttr("eos") + exists: bool = False + + +class PhononData(PropModel): + _prop: str = PrivateAttr("phonon") + exists: bool = False + + +class AbsorptionData(PropModel): + _prop: str = PrivateAttr("absorption") + exists: bool = False - # External Database IDs - database_IDs: dict[str, list[str]] = Field( - {}, description="External database IDs corresponding to this material." 
+class ElectrodesData(PropModel): + _prop: str = PrivateAttr("insertion_electrodes") + exists: bool = False + + +class SubstratesData(PropModel): + _prop: str = PrivateAttr("substrates") + exists: bool = False + + +class ChemenvData(PropModel): + _prop: str = PrivateAttr("chemenv") + exists: bool = False + + +class SummaryDoc( + MaterialsSummary, + ThermoSummary, + XASSummary, + GBSummary, + ElectronicStructureSummary, + BandstructureSummary, + DosSummary, + MagnetismSummary, + ElasticitySummary, + DielectricSummary, + PiezoelectricSummary, + SurfacesSummary, + OxiStatesSummary, + ProvenanceSummary, +): + """ + Summary information about materials and their properties, useful for materials + screening studies and searching. + """ + + has_props: dict[HasProps, bool] | None = Field( + None, + description="List of properties that are available for a given material.", ) @classmethod def from_docs( - cls, material_id: IdentifierType | None = None, **docs: dict[str, dict] - ) -> Self: - """Converts a bunch of summary docs into a SummaryDoc""" - doc = _copy_from_doc(docs) - - # Reshape document for various sub-sections - # Electronic Structure + Bandstructure + DOS - if "bandstructure" in doc: - if doc["bandstructure"] is not None and list( - filter(lambda x: x is not None, doc["bandstructure"].values()) - ): - doc["has_props"]["bandstructure"] = True - else: - del doc["bandstructure"] - if "dos" in doc: - if doc["dos"] is not None and list( - filter(lambda x: x is not None, doc["dos"].values()) - ): - doc["has_props"]["dos"] = True - else: - del doc["dos"] - if "task_id" in doc: - del doc["task_id"] - - return SummaryDoc(material_id=material_id, **doc) - - -# Key mapping -summary_fields: dict[str, list] = { - HasProps(k).value: v - for k, v in { - "materials": [ - "nsites", - "elements", - "nelements", - "composition", - "composition_reduced", - "formula_pretty", - "formula_anonymous", - "chemsys", - "volume", - "density", - "density_atomic", - "symmetry", - "structure", 
- "deprecated", - "task_ids", - "builder_meta", - ], - "thermo": [ - "uncorrected_energy_per_atom", - "energy_per_atom", - "formation_energy_per_atom", - "energy_above_hull", - "is_stable", - "equilibrium_reaction_energy_per_atom", - "decomposes_to", + cls, + property_summary_docs: list[ + MaterialsSummary + | ThermoSummary + | XASSummary + | GBSummary + | ElectronicStructureSummary + | BandstructureSummary + | DosSummary + | MagnetismSummary + | ElasticitySummary + | DielectricSummary + | PiezoelectricSummary + | SurfacesSummary + | OxiStatesSummary + | ProvenanceSummary ], - "xas": ["absorbing_element", "edge", "spectrum_type", "spectrum_id"], - "grain_boundaries": [ - "gb_energy", - "sigma", - "type", - "rotation_angle", - "w_sep", + property_shim_docs: list[ + ChargeDensityData + | EosData + | PhononData + | AbsorptionData + | ElectrodesData + | SubstratesData + | ChemenvData ], - "electronic_structure": [ - "band_gap", - "efermi", - "cbm", - "vbm", - "is_gap_direct", - "is_metal", - "bandstructure", - "dos", - "task_id", - ], - "magnetism": [ - "is_magnetic", - "ordering", - "total_magnetization", - "total_magnetization_normalized_vol", - "total_magnetization_normalized_formula_units", - "num_magnetic_sites", - "num_unique_magnetic_sites", - "types_of_magnetic_species", - "is_magnetic", - ], - "elasticity": [ - "bulk_modulus", - "shear_modulus", - "universal_anisotropy", - "homogeneous_poisson", - ], - "dielectric": ["e_total", "e_ionic", "e_electronic", "n"], - "piezoelectric": ["e_ij_max"], - "surface_properties": [ - "weighted_surface_energy", - "weighted_surface_energy_EV_PER_ANG2", - "shape_factor", - "surface_anisotropy", - "weighted_work_function", - "has_reconstructed", - ], - "oxi_states": ["possible_species"], - "provenance": ["theoretical", "database_IDs"], - "charge_density": [], - "eos": [], - "phonon": [], - "absorption": [], - "insertion_electrodes": [], - "substrates": [], - "chemenv": [], - }.items() -} - - -def _copy_from_doc(doc): - 
"""Helper function to copy the list of keys over from amalgamated document""" - has_props = {str(val.value): False for val in HasProps} - d = {"has_props": has_props, "origins": []} - # Complex function to grab the keys and put them in the root doc - # if the item is a list, it makes one doc per item with those corresponding keys - for doc_key in summary_fields: - sub_doc = doc.get(doc_key, None) - if isinstance(sub_doc, list) and len(sub_doc) > 0: - d["has_props"][doc_key] = True - d[doc_key] = [] - for sub_item in sub_doc: - temp_doc = { - copy_key: sub_item[copy_key] - for copy_key in summary_fields[doc_key] - if copy_key in sub_item - } - d[doc_key].append(temp_doc) - elif isinstance(sub_doc, dict): - d["has_props"][doc_key] = True - if sub_doc.get("origins", None): - d["origins"].extend(sub_doc["origins"]) - d.update( - { - copy_key: sub_doc[copy_key] - for copy_key in summary_fields[doc_key] - if copy_key in sub_doc - } - ) - return d + **kwargs, + ) -> Self: + """ + Args: + property_summary_docs: List of propery documents with data to + be added to SummaryDoc. + property_shim_docs: List of property shim documents strictly + used to populate has_props. + """ + # initialize all has_props variants to False, overwrite according to what caller provides + has_props = {prop.value: False for prop in HasProps} + for prop in chain(property_summary_docs, property_summary_docs): + has_props[prop.name] = prop._has_props + + return SummaryDoc( + has_props=has_props, + **ChainMap(*[doc.model_dump() for doc in property_summary_docs]), + **kwargs, + ) diff --git a/emmet-core/emmet/core/tasks.py b/emmet-core/emmet/core/tasks.py index 5a42886cb8..bf99eaf264 100644 --- a/emmet-core/emmet/core/tasks.py +++ b/emmet-core/emmet/core/tasks.py @@ -316,8 +316,7 @@ class CoreTaskDoc(StructureMetadata): ) task_id: IdentifierType | None = Field( None, - description="The (task) ID of this calculation, used as a universal reference across property documents." 
- "This comes in the form: mp-******.", + description="The (task) ID of this calculation, used as a universal reference across property documents.", ) task_type: TaskType | CalcType | None = Field( None, description="The type of calculation." diff --git a/emmet-core/emmet/core/thermo.py b/emmet-core/emmet/core/thermo.py index 9764f69258..0f28e013b1 100644 --- a/emmet-core/emmet/core/thermo.py +++ b/emmet-core/emmet/core/thermo.py @@ -38,7 +38,7 @@ class DecompositionProduct(BaseModel): ) amount: float | None = Field( None, - description="The amount of the decomposed material by formula units this this material decomposes to.", + description="The amount of the decomposed material by formula units this material decomposes to.", ) diff --git a/emmet-core/emmet/core/types/electronic_structure.py b/emmet-core/emmet/core/types/electronic_structure.py new file mode 100644 index 0000000000..91fce0ba8c --- /dev/null +++ b/emmet-core/emmet/core/types/electronic_structure.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field, model_validator + +from emmet.core.types.pymatgen_types.bandstructure_symm_line_adapter import ( + BandStructureSymmLineType, +) +from emmet.core.types.pymatgen_types.dos_adapter import CompleteDosType +from emmet.core.types.typing import IdentifierType + +lmaxmix = int + + +class DosShim(BaseModel): + """Light wrapper around DOS data - useful for static analysis and runtime safety""" + + dos: tuple[IdentifierType, CompleteDosType, lmaxmix] = Field( + ..., + description="Tuple of a calculation (task) ID, a CompleteDos object, and lmaxmix from the calculation.", + ) + + +class BSShim(BaseModel): + """ + Light wrapper around bandstructure data - useful for static analysis and runtime safety + + At least one field must be populated with bandstructure data. 
+ """ + + setyawan_curtarolo: ( + tuple[IdentifierType, BandStructureSymmLineType, lmaxmix] | None + ) = Field( + None, + description=""" + Tuple of a calculation (task) ID, a BandStructureSymmLine object + from a calculation run using the Setyawan-Curtarolo k-path + convention, and lmaxmix from the calculation. + """, + ) + hinuma: tuple[IdentifierType, BandStructureSymmLineType, lmaxmix] | None = Field( + None, + description=""" + Tuple of a calculation (task) ID, a BandStructureSymmLine object + from a calculation run using the Hinuma et al. k-path + convention, and lmaxmix from the calculation. + """, + ) + latimer_munro: tuple[IdentifierType, BandStructureSymmLineType, lmaxmix] | None = ( + Field( + None, + description=""" + Tuple of a calculation (task) ID, a BandStructureSymmLine object + from a calculation run using the Latimer-Munro et al. k-path + convention, and lmaxmix from the calculation. + """, + ) + ) + + @model_validator(mode="after") + def _has_at_least_one_bandstructure(self): + has_setyawan_curtarolo = self.setyawan_curtarolo is not None + has_hinuma = self.hinuma is not None + has_latimer_munro = self.latimer_munro is not None + + if not (has_setyawan_curtarolo or has_hinuma or has_latimer_munro): + raise ValueError()( + "At least one bandstructure type ('setyawan_curtarolo', 'hinuma', or 'latimer_munro') must be populated" + ) + + return self diff --git a/emmet-core/emmet/core/types/pymatgen_types/bandstructure_symm_line_adapter.py b/emmet-core/emmet/core/types/pymatgen_types/bandstructure_symm_line_adapter.py index 8c1a0eeba1..ab060375a1 100644 --- a/emmet-core/emmet/core/types/pymatgen_types/bandstructure_symm_line_adapter.py +++ b/emmet-core/emmet/core/types/pymatgen_types/bandstructure_symm_line_adapter.py @@ -4,12 +4,12 @@ from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from typing_extensions import TypedDict +from emmet.core.types.pymatgen_types.kpoint_adapter import KpointType from 
emmet.core.types.pymatgen_types.lattice_adapter import MSONableTypedLatticeDict from emmet.core.types.pymatgen_types.structure_adapter import ( TypedStructureDict, pop_empty_structure_keys, ) -from emmet.core.types.typing import TypedBandDict class TypedBandGapDict(TypedDict): @@ -24,6 +24,16 @@ class TypedBranchDict(TypedDict): name: str +class TypedBandDict(TypedDict): + """Type def for data stored for cbms or vbms""" + + band_index: dict[str, list[int]] + kpoint_index: list[int] + kpoint: KpointType + energy: float + projections: dict[str, list[list[float]]] + + TypedBandDictureSymmLineDict = TypedDict( "TypedBandDictureSymmLineDict", { diff --git a/emmet-core/emmet/core/types/pymatgen_types/computed_entries_adapter.py b/emmet-core/emmet/core/types/pymatgen_types/computed_entries_adapter.py index 73868779bf..bde72c1833 100644 --- a/emmet-core/emmet/core/types/pymatgen_types/computed_entries_adapter.py +++ b/emmet-core/emmet/core/types/pymatgen_types/computed_entries_adapter.py @@ -1,15 +1,17 @@ from typing import Annotated, Any, TypeVar import orjson -from pydantic import BeforeValidator, WrapSerializer +from pydantic import BeforeValidator, TypeAdapter, WrapSerializer from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry from typing_extensions import NotRequired, TypedDict +from emmet.core.mpid_ext import ThermoID from emmet.core.types.pymatgen_types.element_adapter import ElementType from emmet.core.types.pymatgen_types.structure_adapter import ( - TypedStructureDict, + StructureType, pop_empty_structure_keys, ) +from emmet.core.types.typing import IdentifierType, MaterialIdentifierType from emmet.core.vasp.calculation import PotcarSpec # TypedEnergyAdjustmentDict = TypedDict( @@ -59,20 +61,20 @@ class TypedCEDataDict(TypedDict): - oxide_type: str - aspherical: bool - last_updated: str - task_id: str - material_id: str - oxidation_states: dict[ElementType, float] - license: str - run_type: str + oxide_type: NotRequired[str | None] 
+ aspherical: NotRequired[bool | None] + last_updated: NotRequired[str | None] + task_id: NotRequired[IdentifierType | None] + material_id: NotRequired[MaterialIdentifierType | None] + oxidation_states: NotRequired[dict[ElementType, float] | None] + license: NotRequired[str | None] + run_type: NotRequired[str | None] class TypedCEParameterDict(TypedDict): - potcar_spec: list[PotcarSpec] - run_type: str - is_hubbard: bool + potcar_spec: NotRequired[list[PotcarSpec] | None] + run_type: NotRequired[str | None] + is_hubbard: NotRequired[bool | None] hubbards: NotRequired[dict[str, float] | None] # type: ignore[type-arg] @@ -83,14 +85,14 @@ class TypedCEParameterDict(TypedDict): "@class": str, "energy": float, "composition": dict[ElementType, float], - "entry_id": str, + "entry_id": ThermoID, "correction": float, # "energy_adjustments": list[ # TypedCompositionEnergyAdjustmentDict # | TypedEnergyAdjustmentDict # | TypedTemperatureEnergyAdjustmentDict # ], - "energy_adjustments": str, + "energy_adjustments": NotRequired[str | None], "parameters": TypedCEParameterDict, "data": TypedCEDataDict, }, @@ -98,7 +100,7 @@ class TypedCEParameterDict(TypedDict): class TypedComputedStructureEntryDict(TypedComputedEntryDict): - structure: TypedStructureDict + structure: StructureType ComputedEntryTypeVar = TypeVar( @@ -116,6 +118,9 @@ class TypedComputedStructureEntryDict(TypedComputedEntryDict): def entry_serializer(entry, nxt, info) -> dict[str, Any]: + # need to beat pmg serialization to get correct id serialization + entry.data = TypeAdapter(TypedCEDataDict).dump_python(entry.data) + default_serialized_object = nxt(entry.as_dict(), info) format = info.context.get("format") if info.context else None @@ -142,11 +147,18 @@ def entry_deserializer(entry: dict[str, Any] | ComputedEntry | ComputedStructure match entry["@class"]: case "ComputedEntry": entry_cls = ComputedEntry + entry_type = TypedComputedEntryDict case "ComputedStructureEntry": entry_cls = ComputedStructureEntry + 
entry_type = TypedComputedStructureEntryDict entry = pop_cse_empty_keys(entry) - if isinstance(entry["energy_adjustments"], str): + # must be before 'energy_adjustments' deserialization + entry = TypeAdapter(entry_type).validate_python(entry) + + if "energy_adjustments" in entry and isinstance( + entry["energy_adjustments"], str + ): entry["energy_adjustments"] = orjson.loads(entry["energy_adjustments"]) return entry_cls.from_dict(entry) # type: ignore[arg-type] diff --git a/emmet-core/emmet/core/types/pymatgen_types/phase_diagram_adapter.py b/emmet-core/emmet/core/types/pymatgen_types/phase_diagram_adapter.py index 4701b9debd..4b5ae8932d 100644 --- a/emmet-core/emmet/core/types/pymatgen_types/phase_diagram_adapter.py +++ b/emmet-core/emmet/core/types/pymatgen_types/phase_diagram_adapter.py @@ -101,9 +101,15 @@ def phase_diagram_serializer(phase_diagram, nxt, info) -> dict[str, Any]: entry.as_dict() for entry in default_serialized_object["computed_data"][key] ] + # ndarray -> list[list[int]] for simplex in default_serialized_object["computed_data"]["simplexes"]: simplex["coords"] = simplex["coords"].tolist() + # ndarray -> list[list[float]] + default_serialized_object["computed_data"]["qhull_data"] = ( + default_serialized_object["computed_data"]["qhull_data"].tolist() + ) + format = info.context.get("format") if info.context else None if format == "arrow": phase_diagram_serde( diff --git a/emmet-core/emmet/core/types/typing.py b/emmet-core/emmet/core/types/typing.py index 81177e8a76..a619fad0ca 100644 --- a/emmet-core/emmet/core/types/typing.py +++ b/emmet-core/emmet/core/types/typing.py @@ -11,15 +11,14 @@ import os from datetime import datetime +from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Union, Any +from typing import TYPE_CHECKING, Annotated, Any, Union import orjson from pydantic import BeforeValidator, Field, PlainSerializer, WrapSerializer -from typing_extensions import TypedDict from emmet.core.mpid 
import MPID, AlphaID -from emmet.core.types.pymatgen_types.kpoint_adapter import KpointType from emmet.core.utils import convert_datetime, utcnow if TYPE_CHECKING: @@ -52,33 +51,42 @@ """ -def _fault_tolerant_id_serde(val: Any, serialize: bool = False) -> Any: +def _fault_tolerant_id_serde( + val: Any, + legacy: bool = False, + serialize: bool = False, + **kwargs, +) -> Any: """Needed for the API and safe de-/serialization behavior.""" try: - alpha_id = AlphaID(val) + alpha_id = AlphaID(val, **kwargs) if serialize: return str(alpha_id) - return alpha_id.formatted + return alpha_id.formatted if legacy else alpha_id except Exception: return val -IdentifierType: TypeAlias = Annotated[ - Union[MPID, AlphaID], - BeforeValidator(_fault_tolerant_id_serde), - PlainSerializer(lambda x: _fault_tolerant_id_serde(x, serialize=True)), -] -"""MPID / AlphaID serde.""" +_id_base_metadata = (BeforeValidator(_fault_tolerant_id_serde),) + +def _make_id_type(render_order, **kwargs) -> Annotated: + match render_order: + case 0: + _order = Union[AlphaID, MPID] + case 1: + _order = Union[MPID, AlphaID] -class TypedBandDict(TypedDict): - """Type def for data stored for cbms or vbms""" + return Annotated[ + _order, + BeforeValidator(partial(_fault_tolerant_id_serde, **kwargs)), + PlainSerializer(partial(_fault_tolerant_id_serde, serialize=True, **kwargs)), + ] - band_index: dict[str, list[int]] - kpoint_index: list[int] - kpoint: KpointType - energy: float - projections: dict[str, list[list[float]]] + +IdentifierType: TypeAlias = _make_id_type(0, padlen=8) +MaterialIdentifierType: TypeAlias = _make_id_type(1, legacy=True, prefix="mp", padlen=8) +"""MPID / AlphaID serde.""" def _ser_json_like(d, default_serializer, info): diff --git a/emmet-core/emmet/core/vasp/material.py b/emmet-core/emmet/core/vasp/material.py index 2bad9e771a..ea72753a4e 100644 --- a/emmet-core/emmet/core/vasp/material.py +++ b/emmet-core/emmet/core/vasp/material.py @@ -1,15 +1,16 @@ """Core definition of a Materials 
Document""" -from typing import Mapping +from typing import Mapping, Self from pydantic import BaseModel, Field from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer, oxide_type from pymatgen.analysis.structure_matcher import StructureMatcher -from pymatgen.entries.computed_entries import ComputedStructureEntry from emmet.core.base import EmmetMeta from emmet.core.material import MaterialsDoc as CoreMaterialsDoc from emmet.core.material import PropertyOrigin +from emmet.core.mpid import AlphaID +from emmet.core.mpid_ext import ThermoID from emmet.core.settings import EmmetSettings from emmet.core.tasks import TaskDoc from emmet.core.types.pymatgen_types.computed_entries_adapter import ( @@ -62,7 +63,7 @@ def from_tasks( ] = SETTINGS.VASP_STRUCTURE_QUALITY_SCORES, use_statics: bool = SETTINGS.VASP_USE_STATICS, commercial_license: bool = True, - ) -> "MaterialsDoc": + ) -> Self: """ Converts a group of tasks into one material @@ -105,7 +106,7 @@ def from_tasks( if use_statics: possible_mat_ids += [task.task_id for task in statics] - material_id = min(possible_mat_ids) + material_id = AlphaID(min(possible_mat_ids), prefix="mp") # Always prefer a static over a structure opt structure_task_quality_scores = {"Structure Optimization": 1, "Static": 2} @@ -207,10 +208,12 @@ def _entry_eval(task: TaskDoc): if relevant_calcs: best_task_doc = relevant_calcs[0] - entry = ComputedStructureEntry( - composition=best_task_doc.output.structure.composition, - correction=0.0, - data={ + entry = { + "@class": "ComputedStructureEntry", + "@module": "pymatgen.entries.computed_entries", + "composition": best_task_doc.output.structure.composition, + "correction": 0.0, + "data": { "aspherical": best_task_doc.input.parameters.get( "LASPH", False ), @@ -219,9 +222,9 @@ def _entry_eval(task: TaskDoc): "material_id": material_id, "task_id": best_task_doc.task_id, }, - energy=best_task_doc.output.energy, - entry_id="{}-{}".format(material_id, rt.value), - parameters={ + "energy": 
best_task_doc.output.energy, + "entry_id": ThermoID(identifier=material_id, suffix=rt), + "parameters": { "hubbards": best_task_doc.input.hubbards, "is_hubbard": best_task_doc.input.is_hubbard, "potcar_spec": ( @@ -231,8 +234,8 @@ def _entry_eval(task: TaskDoc): ), "run_type": str(best_task_doc.run_type), }, - structure=best_task_doc.output.structure, - ) + "structure": best_task_doc.output.structure, + } entries[rt] = entry if not any( diff --git a/emmet-core/emmet/core/vasp/validation_legacy.py b/emmet-core/emmet/core/vasp/validation_legacy.py index d643c9fe62..76a5bbf9b2 100644 --- a/emmet-core/emmet/core/vasp/validation_legacy.py +++ b/emmet-core/emmet/core/vasp/validation_legacy.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self import numpy as np from pydantic import BaseModel, Field, ImportString @@ -55,27 +55,20 @@ class ValidationDoc(EmmetBaseModel, extra="allow"): description="Dictioary of data used to perform validation." 
" Useful for post-mortem analysis", ) - nelements: int | None = Field(None, description="Number of elements.") - symmetry_number: int | None = Field( - None, - title="Space Group Number", - description="The spacegroup number for the lattice.", - ) - chemsys: str | None = Field(None) - formula_pretty: str | None = Field(None) @classmethod def from_task_doc( cls, task_doc: CoreTaskDoc | TaskDoc | TaskDocument, - kpts_tolerance: float = SETTINGS.VASP_KPTS_TOLERANCE, + kpts_tolerance: float | None = None, kspacing_tolerance: float = SETTINGS.VASP_KSPACING_TOLERANCE, input_sets: dict[str, ImportString] = SETTINGS.VASP_DEFAULT_INPUT_SETS, LDAU_fields: list[str] = SETTINGS.VASP_CHECKED_LDAU_FIELDS, max_allowed_scf_gradient: float = SETTINGS.VASP_MAX_SCF_GRADIENT, max_magmoms: dict[str, float] = SETTINGS.VASP_MAX_MAGMOM, potcar_stats: dict[CalcType, dict[str, str]] | None = None, - ) -> "ValidationDoc": + bad_tags: list[str] | None = None, + ) -> Self: """ Determines if a calculation is valid based on expected input parameters from a pymatgen inputset @@ -89,17 +82,16 @@ def from_task_doc( max_allowed_scf_gradient: maximum uphill gradient allowed for SCF steps after the initial equillibriation period potcar_stats: Dictionary of potcar stat data. Mapping is calculation type -> potcar symbol -> hash value. 
+ bad_tags: List of tags for calculations to deprecate """ - - nelements = task_doc.nelements or None - symmetry_number = task_doc.symmetry.number if task_doc.symmetry else None + if not kpts_tolerance: + kpts_tolerance = 0.4 if "mp_production_old" in task_doc.tags else 0.9 bandgap = task_doc.output.bandgap calc_type = task_doc.calc_type task_type = task_doc.task_type run_type = task_doc.run_type chemsys = task_doc.chemsys - formula_pretty = task_doc.formula_pretty if isinstance(task_doc, (TaskDoc, TaskDocument)): inputs = task_doc.orig_inputs @@ -227,18 +219,16 @@ def from_task_doc( else: reasons.append(DeprecationMessage.SET) + if len(list(set(task_doc.tags).intersection(bad_tags))) > 0: + warnings.append(f"Manual Deprecation by tags: {bad_tags}") + reasons.append(DeprecationMessage.MANUAL) + doc = ValidationDoc( task_id=task_doc.task_id, - calc_type=calc_type, - run_type=task_doc.run_type, valid=len(reasons) == 0, reasons=reasons, data=data, warnings=warnings, - nelements=nelements, - symmetry_number=symmetry_number, - chemsys=chemsys, - formula_pretty=formula_pretty, ) return doc @@ -247,7 +237,9 @@ def from_task_doc( def _get_input_set(run_type, task_type, calc_type, structure, input_sets, bandgap): # Ensure inputsets get proper additional input values if "SCAN" in run_type.value: - valid_input_set: VaspInputSet = input_sets[str(calc_type)](structure, bandgap=bandgap) # type: ignore + valid_input_set: VaspInputSet = input_sets[str(calc_type)]( + structure, bandgap=bandgap + ) # type: ignore elif task_type == TaskType.NSCF_Uniform or task_type == TaskType.NSCF_Line: # Constructing the k-path for line-mode calculations is too costly, so # the uniform input set is used instead and k-points are not checked. 
diff --git a/emmet-core/pyproject.toml b/emmet-core/pyproject.toml index b671171fb8..a46998fd57 100644 --- a/emmet-core/pyproject.toml +++ b/emmet-core/pyproject.toml @@ -52,6 +52,7 @@ all = [ "pymatgen-analysis-diffusion>=2024.7.15", "pymatgen-analysis-alloys>=0.0.6", "pyarrow", + "pycodcif", ] test = [ "pre-commit", diff --git a/emmet-core/tests/test_model_fields.py b/emmet-core/tests/test_model_fields.py index eeae509150..bb9f212bbf 100644 --- a/emmet-core/tests/test_model_fields.py +++ b/emmet-core/tests/test_model_fields.py @@ -1,8 +1,8 @@ """Ensure that document models used in API do not change fields.""" -import pytest from importlib import import_module +import pytest ref_model_fields = { "emmet.core._general_store.GeneralStoreDoc": [ @@ -964,7 +964,6 @@ "efermi", "is_gap_direct", "is_metal", - "es_source_calc_id", "bandstructure", "dos", "dos_energy_up",