From eec3acd3bc1a0138da5f0a66843be2df8c5f8604 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Fri, 27 Mar 2026 19:34:52 +0300 Subject: [PATCH 01/14] Initial work --- arc/plotter.py | 153 ++++++++++++++++++++++++++++++++++++++++++++ arc/plotter_test.py | 34 ++++++++++ arc/scheduler.py | 128 +++++++++++++++++++++++++++++++++++- 3 files changed, 313 insertions(+), 2 deletions(-) diff --git a/arc/plotter.py b/arc/plotter.py index d0f6938e84..3ef287b166 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -2,6 +2,7 @@ A module for plotting and saving output files such as RMG libraries. """ +import datetime import matplotlib # Force matplotlib to not use any Xwindows backend. # This must be called before pylab, matplotlib.pyplot, or matplotlib.backends is imported. @@ -12,10 +13,15 @@ import numpy as np import os import shutil +import textwrap from matplotlib.backends.backend_pdf import PdfPages from mpl_toolkits.mplot3d import Axes3D from typing import List, Optional, Tuple, Union +try: + import graphviz +except ImportError: + graphviz = None import py3Dmol as p3D from rdkit import Chem @@ -54,6 +60,153 @@ logger = get_logger() +def _sanitize_graphviz_id(value: str) -> str: + """Return a Graphviz-safe identifier.""" + return ''.join(ch if ch.isalnum() else '_' for ch in value) + + +def _wrap_graph_label(text: str, width: int = 24) -> str: + """Wrap long labels so graph nodes stay readable.""" + return '\n'.join(textwrap.wrap(str(text), width=width)) if text else '' + + +def save_provenance_artifacts(project_directory: str, + provenance: dict, + ) -> dict: + """ + Save provenance YAML and render Graphviz artifacts for an ARC run. + + Args: + project_directory (str): The ARC project directory. + provenance (dict): A provenance dictionary with an ``events`` list. + + Returns: + dict: Paths to generated artifacts. 
+ """ + output_directory = os.path.join(project_directory, 'output') + os.makedirs(output_directory, exist_ok=True) + yml_path = os.path.join(output_directory, 'provenance.yml') + dot_path = os.path.join(output_directory, 'provenance.dot') + svg_path = os.path.join(output_directory, 'provenance.svg') + + save_yaml_file(path=yml_path, content=provenance) + + run_label = provenance.get('project', 'ARC run') + if graphviz is None: + logger.warning('The graphviz Python package is not available, so ARC will only save provenance.yml.') + return {'yml': yml_path, 'dot': None, 'svg': None} + + graph = graphviz.Digraph( + name='arc_provenance', + comment=f'ARC provenance for {run_label}', + graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'}, + node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'}, + edge_attr={'fontname': 'Helvetica'}, + ) + run_node_id = _sanitize_graphviz_id(f"run_{provenance.get('run_id', run_label)}") + run_header = provenance.get('started_at', '') + run_footer = provenance.get('ended_at', '') + run_text = f'{run_label}' + if run_header: + run_text += f'\nstart: {run_header}' + if run_footer: + run_text += f'\nend: {run_footer}' + graph.node(run_node_id, _wrap_graph_label(run_text, width=32), shape='oval', fillcolor='lightgoldenrod1') + + species_nodes, job_nodes = dict(), dict() + last_node_by_label = dict() + + for event in provenance.get('events', list()): + event_type = event.get('event_type', '') + label = event.get('label') + if label and label not in species_nodes: + species_node_id = _sanitize_graphviz_id(f'species_{label}') + species_text = label + if event.get('is_ts'): + species_text += '\nTS' + graph.node(species_node_id, _wrap_graph_label(species_text), fillcolor='aliceblue') + graph.edge(run_node_id, species_node_id) + species_nodes[label] = species_node_id + last_node_by_label[label] = species_node_id + + if event_type == 'job_started': + job_key = event.get('job_key', 
event.get('job_name', 'job')) + job_node_id = _sanitize_graphviz_id(f'job_{job_key}') + job_text = f"{event.get('job_type', 'job')}\n{event.get('job_name', job_key)}" + if event.get('job_adapter'): + job_text += f"\n{event['job_adapter']}" + if event.get('level'): + job_text += f"\n{event['level']}" + graph.node(job_node_id, _wrap_graph_label(job_text), fillcolor='white') + source_node_id = run_node_id if label is None else last_node_by_label.get(label, species_nodes.get(label)) + if source_node_id is not None: + edge_label = event.get('provenance_reason') or '' + graph.edge(source_node_id, job_node_id, label=edge_label) + if label is not None: + last_node_by_label[label] = job_node_id + job_nodes[job_key] = job_node_id + + elif event_type == 'job_finished': + job_key = event.get('job_key') + if job_key in job_nodes: + status = event.get('status', 'unknown') + fillcolor = {'done': 'honeydew', 'errored': 'mistyrose'}.get(status, 'lightyellow') + graph.node(job_nodes[job_key], fillcolor=fillcolor) + + result_node_id = _sanitize_graphviz_id( + f"result_{event.get('event_id', len(job_nodes))}_{job_key}" + ) + result_text = f"{status}" + if event.get('run_time'): + result_text += f"\n{event['run_time']}" + if event.get('keywords'): + result_text += f"\n{', '.join(event['keywords'])}" + graph.node(result_node_id, _wrap_graph_label(result_text), shape='note', fillcolor='cornsilk') + graph.edge(job_nodes[job_key], result_node_id) + if label is not None: + last_node_by_label[label] = result_node_id + + elif event_type in ['ts_guess_selected', 'job_troubleshooting']: + decision_node_id = _sanitize_graphviz_id(f"decision_{event.get('event_id', 0)}") + if event_type == 'ts_guess_selected': + decision_text = f"Select TS guess {event.get('selected_index')}" + if event.get('method'): + decision_text += f"\n{event['method']}" + fillcolor = 'lavender' + else: + decision_text = f"Troubleshoot {event.get('job_name', '')}" + if event.get('methods'): + decision_text += f"\n{', 
'.join(event['methods'])}" + fillcolor = 'moccasin' + graph.node(decision_node_id, _wrap_graph_label(decision_text), shape='diamond', fillcolor=fillcolor) + source_job_key = event.get('job_key') + source_node_id = job_nodes.get(source_job_key) if source_job_key else last_node_by_label.get(label) + if source_node_id is None and label is not None: + source_node_id = species_nodes.get(label) + if source_node_id is not None: + graph.edge(source_node_id, decision_node_id) + if label is not None: + last_node_by_label[label] = decision_node_id + + elif event_type == 'species_initialized' and label in species_nodes: + continue + + with open(dot_path, 'w') as f: + f.write(graph.source) + + try: + svg_data = graph.pipe(format='svg') + except (graphviz.ExecutableNotFound, graphviz.CalledProcessError): + logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.') + else: + with open(svg_path, 'wb') as f: + f.write(svg_data) + + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) + return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if os.path.isfile(svg_path) else None} + + # *** Drawings species *** def draw_structure(xyz=None, species=None, project_directory=None, method='show_sticks', show_atom_indices=False): diff --git a/arc/plotter_test.py b/arc/plotter_test.py index ba6984dae4..0a300f6ab6 100644 --- a/arc/plotter_test.py +++ b/arc/plotter_test.py @@ -218,6 +218,40 @@ def test_save_irc_traj_animation(self): plotter.save_irc_traj_animation(irc_f_path, irc_r_path, out_path) self.assertTrue(os.path.isfile(out_path)) + def test_save_provenance_artifacts(self): + """Test saving ARC provenance YAML / Graphviz artifacts.""" + project = 'arc_project_for_testing_delete_after_usage' + project_directory = os.path.join(ARC_PATH, 'Projects', project) + provenance = { + 'project': project, + 'run_id': 'run_1', + 'started_at': '2026-03-15T10:00:00', + 
'ended_at': '2026-03-15T10:05:00', + 'events': [ + {'event_id': 1, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', 'label': 'spc1'}, + {'event_id': 2, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:00:01', + 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'job_name': 'opt_a1', 'job_type': 'opt', + 'job_adapter': 'gaussian', 'level': 'b3lyp/6-31g(d)'}, + {'event_id': 3, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:00', + 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'job_name': 'opt_a1', 'job_type': 'opt', + 'status': 'done', 'run_time': '0:01:00'}, + {'event_id': 4, 'event_type': 'job_troubleshooting', 'timestamp': '2026-03-15T10:01:05', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'job_name': 'freq_a2', 'job_type': 'freq', + 'methods': ['memory']}, + {'event_id': 5, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:10', + 'label': 'spc1', 'job_key': 'spc1:freq_a3', 'job_name': 'freq_a3', 'job_type': 'freq', + 'job_adapter': 'gaussian', 'provenance_reason': 'ess_troubleshoot'}, + ], + } + paths = plotter.save_provenance_artifacts(project_directory=project_directory, provenance=provenance) + self.assertTrue(os.path.isfile(paths['yml'])) + if paths['dot'] is not None: + self.assertTrue(os.path.isfile(paths['dot'])) + with open(paths['dot'], 'r') as f: + dot = f.read() + self.assertIn('spc1', dot) + self.assertIn('opt_a1', dot) + @classmethod def tearDownClass(cls): diff --git a/arc/scheduler.py b/arc/scheduler.py index 0b4ed71762..596f959755 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -9,6 +9,7 @@ import pprint import shutil import time +from typing import Any import numpy as np from typing import TYPE_CHECKING, List, Optional, Tuple, Union @@ -297,12 +298,20 @@ def __init__(self, self.output_multi_spc = dict() self.report_e_elect = report_e_elect self.skip_nmd = skip_nmd + self.provenance = {'version': 1, + 'project': self.project, + 'run_id': 
f'{self.project}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}', + 'started_at': datetime.datetime.now().isoformat(timespec='seconds'), + 'events': list(), + } + self.provenance_path = os.path.join(self.project_directory, 'output', 'provenance.yml') self.species_dict, self.rxn_dict = dict(), dict() for species in self.species_list: self.species_dict[species.label] = species for rxn in self.rxn_list: self.rxn_dict[rxn.index] = rxn + self._initialize_provenance() if self.restart_dict is not None: self.output = self.restart_dict['output'] if 'output' in self.restart_dict else dict() self.output_multi_spc = self.restart_dict['output_multi_spc'] if 'output_multi_spc' in self.restart_dict else dict() @@ -510,6 +519,55 @@ def __init__(self, if not self.testing: self.schedule_jobs() + def _initialize_provenance(self): + """Load previous provenance when restarting and record the current run start.""" + if os.path.isfile(self.provenance_path): + try: + provenance = read_yaml_file(self.provenance_path) + except Exception: + provenance = None + if isinstance(provenance, dict): + events = provenance.get('events', list()) + self.provenance.update({key: val for key, val in provenance.items() if key != 'events'}) + self.provenance['events'] = events + for species in self.species_list: + self.record_provenance_event(event_type='species_initialized', + label=species.label, + is_ts=species.is_ts, + ) + + def record_provenance_event(self, + event_type: str, + label: Optional[str] = None, + **data: Any, + ): + """Append a provenance event and persist the event log.""" + event = {'event_id': len(self.provenance['events']) + 1, + 'event_type': event_type, + 'timestamp': datetime.datetime.now().isoformat(timespec='seconds'), + } + if label is not None: + event['label'] = label + for key, value in data.items(): + if value is not None and value != '' and value != list(): + event[key] = value + self.provenance['events'].append(event) + self.save_provenance() + + def 
save_provenance(self): + """Persist the provenance event log.""" + output_directory = os.path.dirname(self.provenance_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + save_yaml_file(path=self.provenance_path, content=self.provenance) + + def finalize_provenance(self): + """Render final provenance artifacts after the run completes.""" + self.provenance['ended_at'] = datetime.datetime.now().isoformat(timespec='seconds') + plotter.save_provenance_artifacts(project_directory=self.project_directory, + provenance=self.provenance, + ) + def schedule_jobs(self): """ The main job scheduling block @@ -741,6 +799,7 @@ def schedule_jobs(self): # Generate a TS report: self.generate_final_ts_guess_report() + self.finalize_provenance() def run_job(self, job_type: str, @@ -767,6 +826,8 @@ def run_job(self, torsions: Optional[List[List[int]]] = None, times_rerun: int = 0, tsg: Optional[int] = None, + provenance_parent_job: Optional[str] = None, + provenance_reason: Optional[str] = None, xyz: Optional[Union[dict, List[dict]]]= None, ): """ @@ -898,6 +959,23 @@ def run_job(self, if job.server is not None and job.server not in self.servers: self.servers.append(job.server) self.check_max_simultaneous_jobs_limit(job.server) + level_repr = None if job.level is None else str(job.level) + self.record_provenance_event( + event_type='job_started', + label=label, + is_ts=self.species_dict[label].is_ts if isinstance(label, str) and label in self.species_dict else None, + job_key=f'{label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + job_adapter=job.job_adapter, + level=level_repr, + execution_type=job.execution_type, + ess_trsh_methods=job.ess_trsh_methods, + conformer=conformer, + tsg=tsg, + provenance_parent_job=provenance_parent_job, + provenance_reason=provenance_reason, + ) job.execute() self.save_restart_dict() @@ -1018,6 +1096,18 @@ def end_job(self, job: 'JobAdapter', self.timer = False job.write_completed_job_to_csv_file() 
logger.info(f' Ending job {job_name} for {label} (run time: {job.run_time})') + self.record_provenance_event( + event_type='job_finished', + label=label, + is_ts=self.species_dict[label].is_ts if label in self.species_dict else None, + job_key=f'{label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + status=job.job_status[1]['status'] if job.job_status[1]['status'] else job.job_status[0], + keywords=job.job_status[1]['keywords'], + error=job.job_status[1]['error'], + run_time=str(job.run_time) if job.run_time is not None else None, + ) if job.job_status[0] != 'done': return False if job.job_adapter in ['gaussian', 'terachem'] and os.path.isfile(os.path.join(job.local_path, 'check.chk')) \ @@ -1074,6 +1164,8 @@ def _run_a_job(self, torsions=job.torsions, times_rerun=job.times_rerun + int(rerun), tsg=job.tsg, + provenance_parent_job=job.job_name, + provenance_reason='rerun', xyz=job.xyz, ) @@ -1972,8 +2064,12 @@ def parse_conformer(self, logger.warning(f'Conformer {i} for {label} did not converge.') if job.job_status[1]['status'] == 'errored' and job.times_rerun == 0: job.times_rerun += 1 - self.troubleshoot_ess(label=label, job=job, level_of_theory=job.level, conformer= job.conformer if job.conformer is not None else None) - return True + self.troubleshoot_ess(label=label, + job=job, + level_of_theory=job.level, + conformer=job.conformer if job.conformer is not None else None) + # Report "still troubleshooting" only if another job was actually queued. 
+ return label in self.running_jobs and job.job_name in self.running_jobs[label] if job.times_rerun == 0 and self.trsh_ess_jobs: self._run_a_job(job=job, label=label, rerun=True) return True @@ -2186,6 +2282,10 @@ def determine_most_likely_ts_conformer(self, label: str): logger.warning(f'Could not determine a likely TS conformer for {label}') self.species_dict[label].ts_number, self.species_dict[label].chosen_ts = None, None self.species_dict[label].populate_ts_checks() + self.record_provenance_event(event_type='ts_guess_selection_failed', + label=label, + is_ts=True, + ) return None else: rxn_txt = '' if self.species_dict[label].rxn_label is None \ @@ -2203,6 +2303,13 @@ def determine_most_likely_ts_conformer(self, label: str): self.species_dict[label].initial_xyz = tsg.opt_xyz self.species_dict[label].final_xyz = None self.species_dict[label].ts_guesses_exhausted = False + self.record_provenance_event(event_type='ts_guess_selected', + label=label, + is_ts=True, + selected_index=selected_i, + method=tsg.method, + energy=tsg.energy, + ) if tsg.success and tsg.energy is not None: # guess method and ts_level opt were both successful tsg.energy -= e_min im_freqs = f', imaginary frequencies {tsg.imaginary_freqs}' if tsg.imaginary_freqs is not None else '' @@ -3446,6 +3553,16 @@ def troubleshoot_ess(self, job.ess_trsh_methods = ess_trsh_methods if not couldnt_trsh: + self.record_provenance_event(event_type='job_troubleshooting', + label=label, + is_ts=self.species_dict[label].is_ts, + job_key=f'{label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + methods=ess_trsh_methods, + keywords=job.job_status[1]['keywords'], + error=job.job_status[1]['error'], + ) self.run_job(label=label, xyz=xyz, level_of_theory=level_of_theory, @@ -3462,8 +3579,15 @@ def troubleshoot_ess(self, rotor_index=job.rotor_index, cpu_cores=cpu_cores, shift=shift, + provenance_parent_job=job.job_name, + provenance_reason='ess_troubleshoot', ) elif self.species_dict[label].is_ts 
and not self.species_dict[label].ts_guesses_exhausted: + # During TS conf_opt screening, avoid switching mid-batch since switch_ts() deletes all + # running jobs for this TS label and can discard other viable TS guesses still running. + if job.job_type == 'conf_opt': + self.save_restart_dict() + return None logger.info(f'TS {label} did not converge. ' f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') From c72df97fae5588f0467a0504cdfa09a746eec596 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 15:57:49 +0300 Subject: [PATCH 02/14] Added graphviz to environment --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 5f22a9c40a..1ac6654f4f 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - conda-forge::ffmpeg - conda-forge::gprof2dot - conda-forge::graphviz + - conda-forge::python-graphviz - conda-forge::h5py - conda-forge::ipython - conda-forge::jupyter From 15da793233dd513b239dfe47b139d757043132aa Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 15:58:17 +0300 Subject: [PATCH 03/14] Enhance provenance tracking and restart consistency in the Scheduler - Improve provenance logging by avoiding duplicate initialization events and handling potentially corrupted provenance files. - Ensure internal consistency on restart by verifying that species marked as converged have all required output paths, resetting their status otherwise. - Fix job key generation for reactions (lists of labels) and improve tracking for running conformer jobs. - Defer TS switching during conformer optimization batches to avoid unnecessary job deletions. 
--- arc/scheduler.py | 108 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 92 insertions(+), 16 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index efc1876d16..536ecfdd61 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -9,10 +9,8 @@ import pprint import shutil import time -from typing import Any - import numpy as np -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import Any, TYPE_CHECKING, List, Optional, Tuple, Union import arc.parser.parser as parser from arc import plotter @@ -334,6 +332,8 @@ def __init__(self, self.orbitals_level = orbitals_level self.unique_species_labels = list() self.save_restart = False + if self.restart_dict is not None: + self._sanitize_restart_output() if len(self.rxn_list): rxn_info_path = self.make_reaction_labels_info_file() @@ -525,16 +525,18 @@ def _initialize_provenance(self): try: provenance = read_yaml_file(self.provenance_path) except Exception: + logger.warning('Could not parse existing provenance.yml; starting a fresh provenance log.') provenance = None if isinstance(provenance, dict): - events = provenance.get('events', list()) - self.provenance.update({key: val for key, val in provenance.items() if key != 'events'}) - self.provenance['events'] = events + self.provenance['events'] = provenance.get('events', list()) + already_initialized = {e['label'] for e in self.provenance['events'] + if e.get('event_type') == 'species_initialized' and 'label' in e} for species in self.species_list: - self.record_provenance_event(event_type='species_initialized', - label=species.label, - is_ts=species.is_ts, - ) + if species.label not in already_initialized: + self.record_provenance_event(event_type='species_initialized', + label=species.label, + is_ts=species.is_ts, + ) def record_provenance_event(self, event_type: str, @@ -856,6 +858,8 @@ def run_job(self, torsions (List[List[int]], optional): The 0-indexed atom indices of the torsion(s). 
trsh (str, optional): A troubleshooting keyword to be used in input files. tsg (int, optional): TSGuess number if optimizing TS guesses. + provenance_parent_job (str, optional): The job_name of the parent job that triggered this one. + provenance_reason (str, optional): Why this job was spawned (e.g., 'rerun', 'ess_troubleshoot', 'fine_opt'). xyz (Union[dict, List[dict]], optional): The 3D coordinates for the species. """ max_job_time = max_job_time or self.max_job_time # if it's None, set to default @@ -960,11 +964,12 @@ def run_job(self, self.servers.append(job.server) self.check_max_simultaneous_jobs_limit(job.server) level_repr = None if job.level is None else str(job.level) + provenance_label = '+'.join(label) if isinstance(label, list) else label self.record_provenance_event( event_type='job_started', - label=label, + label=provenance_label, is_ts=self.species_dict[label].is_ts if isinstance(label, str) and label in self.species_dict else None, - job_key=f'{label}:{job.job_name}', + job_key=f'{provenance_label}:{job.job_name}', job_name=job.job_name, job_type=job.job_type, job_adapter=job.job_adapter, @@ -1099,7 +1104,7 @@ def end_job(self, job: 'JobAdapter', self.record_provenance_event( event_type='job_finished', label=label, - is_ts=self.species_dict[label].is_ts if label in self.species_dict else None, + is_ts=self.species_dict[label].is_ts if isinstance(label, str) and label in self.species_dict else None, job_key=f'{label}:{job.job_name}', job_name=job.job_name, job_type=job.job_type, @@ -2069,7 +2074,9 @@ def parse_conformer(self, level_of_theory=job.level, conformer=job.conformer if job.conformer is not None else None) # Report "still troubleshooting" only if another job was actually queued. - return label in self.running_jobs and job.job_name in self.running_jobs[label] + # Conformer jobs are tracked in running_jobs as '{job_type}_{conformer}', not by job_name. 
+ running_key = f'{job.job_type}_{job.conformer}' if job.conformer is not None else job.job_name + return label in self.running_jobs and running_key in self.running_jobs[label] if job.times_rerun == 0 and self.trsh_ess_jobs: self._run_a_job(job=job, label=label, rerun=True) return True @@ -2484,6 +2491,8 @@ def parse_opt_geo(self, level_of_theory=job.level, job_type='opt', fine=True, + provenance_parent_job=job.job_name, + provenance_reason='fine_opt', ) else: success = True @@ -2726,7 +2735,6 @@ def switch_ts(self, label: str): logger.info(f'Switching a TS guess for {label}...') self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess. self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. - self.output[label]['geo'] = self.output[label]['freq'] = self.output[label]['sp'] = self.output[label]['composite'] = '' freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path): os.remove(freq_path) @@ -3151,6 +3159,9 @@ def check_all_done(self, label: str): logger.debug(f'Species {label} did not converge.') all_converged = False break + if all_converged and self._missing_required_paths(label): + logger.debug(f'Species {label} did not converge due to missing output paths.') + all_converged = False if label in self.output and all_converged: self.output[label]['convergence'] = True if self.species_dict[label].is_ts: @@ -3191,6 +3202,64 @@ def check_all_done(self, label: str): # Update restart dictionary and save the yaml restart file: self.save_restart_dict() + def _missing_required_paths(self, label: str) -> bool: + """ + Check whether required output paths are missing for a species/TS. + + Args: + label (str): The species label. + + Returns: + bool: Whether required output paths are missing. 
+ """ + return bool(self._get_missing_required_paths(label)) + + def _get_missing_required_paths(self, label: str) -> set: + """ + Get missing required output path job types for a species/TS. + + Args: + label (str): The species label. + + Returns: + set: Job types with missing required output paths. + """ + if label not in self.output or 'paths' not in self.output[label]: + return set() + path_map = { + 'opt': 'geo', + 'freq': 'freq', + 'sp': 'sp', + 'composite': 'composite', + } + missing = set() + for job_type, path_key in path_map.items(): + if job_type == 'composite': + required = self.composite_method is not None + else: + required = self.job_types.get(job_type, False) + if not required: + continue + if self.species_dict[label].number_of_atoms == 1 and job_type in ['opt', 'freq']: + continue + if self.output[label]['job_types'].get(job_type, False) and not self.output[label]['paths'].get(path_key, ''): + missing.add(job_type) + return missing + + def _sanitize_restart_output(self) -> None: + """ + Ensure restart output state is internally consistent (e.g., convergence without paths). + """ + for label in list(self.output.keys()): + if label not in self.species_dict: + continue + missing_job_types = self._get_missing_required_paths(label) + if self.output[label].get('convergence') and missing_job_types: + self.output[label]['convergence'] = False + if 'job_types' in self.output[label]: + for job_type in missing_job_types: + self.output[label]['job_types'][job_type] = False + def get_server_job_ids(self, specific_server: Optional[str] = None): """ Check job status on a specific server or on all active servers, get a list of relevant running job IDs. @@ -3586,6 +3655,7 @@ def troubleshoot_ess(self, # During TS conf_opt screening, avoid switching mid-batch since switch_ts() deletes all # running jobs for this TS label and can discard other viable TS guesses still running. 
if job.job_type == 'conf_opt': + logger.debug(f'Deferring TS switch for {label} during conf_opt batch screening.') self.save_restart_dict() return None logger.info(f'TS {label} did not converge. ' @@ -3671,7 +3741,13 @@ def delete_all_species_jobs(self, label: str): logger.info(f'Deleted job {job_name}') job.delete() self.running_jobs[label] = list() - self.output[label]['paths'] = {key: '' if key != 'irc' else list() for key in self.output[label]['paths'].keys()} + if label in self.output: + self.output[label]['convergence'] = False + for key in ['opt', 'freq', 'sp', 'composite', 'fine']: + if key in self.output[label]['job_types']: + self.output[label]['job_types'][key] = False + self.output[label]['paths'] = {key: '' if key != 'irc' else list() + for key in self.output[label]['paths'].keys()} def restore_running_jobs(self): """ From 989a9fd1df4fbda8aaf9c044976c370ca0aed7d7 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 15:58:32 +0300 Subject: [PATCH 04/14] Deduplicate and format methods in the TS report Ensure that successful and unsuccessful transition state generation methods are listed uniquely and formatted using join to avoid trailing commas in the species report. 
--- arc/species/species.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arc/species/species.py b/arc/species/species.py index a94ce01c00..f5ae77a234 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -1536,12 +1536,12 @@ def make_ts_report(self): self.ts_report += ':\n' if self.successful_methods: self.ts_report += 'Methods that successfully generated a TS guess:\n' - for successful_method in self.successful_methods: - self.ts_report += successful_method + ',' + unique_successful_methods = list(dict.fromkeys(self.successful_methods)) + self.ts_report += ','.join(unique_successful_methods) if self.unsuccessful_methods: self.ts_report += '\nMethods that were unsuccessfully in generating a TS guess:\n' - for unsuccessful_method in self.unsuccessful_methods: - self.ts_report += unsuccessful_method + ',' + unique_unsuccessful_methods = list(dict.fromkeys(self.unsuccessful_methods)) + self.ts_report += ','.join(unique_unsuccessful_methods) if not self.ts_guesses_exhausted: self.ts_report += f'\nThe method that generated the best TS guess and its output used for the ' \ f'optimization: {self.chosen_ts_method}\n' From 16207d341e8b0989cfb7f688c9ed64e58e788c86 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 17:22:29 +0300 Subject: [PATCH 05/14] Improve provenance graph structure and visualization - Update graph logic to correctly link jobs to parent jobs, troubleshooting diamonds, or TS selection decisions instead of always defaulting to the last node. - Preserve intentional newlines in wrapped labels to improve node readability. - Ensure the provenance YAML file is saved with an updated timestamp even when the graphviz package is unavailable. - Add support for visualizing TS guess selection failure events as decision nodes. 
--- arc/plotter.py | 51 +++++++++++++++++++++++------------- arc/plotter_test.py | 64 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 26 deletions(-) diff --git a/arc/plotter.py b/arc/plotter.py index 3ef287b166..c84c66aff8 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -66,8 +66,11 @@ def _sanitize_graphviz_id(value: str) -> str: def _wrap_graph_label(text: str, width: int = 24) -> str: - """Wrap long labels so graph nodes stay readable.""" - return '\n'.join(textwrap.wrap(str(text), width=width)) if text else '' + """Wrap long labels so graph nodes stay readable, preserving intentional newlines.""" + if not text: + return '' + return '\n'.join(line for part in str(text).split('\n') + for line in (textwrap.wrap(part, width=width) or [''])) def save_provenance_artifacts(project_directory: str, @@ -89,11 +92,11 @@ def save_provenance_artifacts(project_directory: str, dot_path = os.path.join(output_directory, 'provenance.dot') svg_path = os.path.join(output_directory, 'provenance.svg') - save_yaml_file(path=yml_path, content=provenance) - run_label = provenance.get('project', 'ARC run') if graphviz is None: logger.warning('The graphviz Python package is not available, so ARC will only save provenance.yml.') + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) return {'yml': yml_path, 'dot': None, 'svg': None} graph = graphviz.Digraph( @@ -113,8 +116,11 @@ def save_provenance_artifacts(project_directory: str, run_text += f'\nend: {run_footer}' graph.node(run_node_id, _wrap_graph_label(run_text, width=32), shape='oval', fillcolor='lightgoldenrod1') - species_nodes, job_nodes = dict(), dict() - last_node_by_label = dict() + species_nodes = dict() + job_nodes = dict() + # Track the most recent decision node (troubleshoot / TS selection) per label, + # so that follow-up jobs spawned by that decision connect from the diamond. 
+ last_decision_by_label = dict() for event in provenance.get('events', list()): event_type = event.get('event_type', '') @@ -127,7 +133,6 @@ def save_provenance_artifacts(project_directory: str, graph.node(species_node_id, _wrap_graph_label(species_text), fillcolor='aliceblue') graph.edge(run_node_id, species_node_id) species_nodes[label] = species_node_id - last_node_by_label[label] = species_node_id if event_type == 'job_started': job_key = event.get('job_key', event.get('job_name', 'job')) @@ -138,12 +143,21 @@ def save_provenance_artifacts(project_directory: str, if event.get('level'): job_text += f"\n{event['level']}" graph.node(job_node_id, _wrap_graph_label(job_text), fillcolor='white') - source_node_id = run_node_id if label is None else last_node_by_label.get(label, species_nodes.get(label)) - if source_node_id is not None: - edge_label = event.get('provenance_reason') or '' - graph.edge(source_node_id, job_node_id, label=edge_label) - if label is not None: - last_node_by_label[label] = job_node_id + + # Determine the source node for this job's incoming edge. + parent_job = event.get('provenance_parent_job') + reason = event.get('provenance_reason', '') + if parent_job and label in last_decision_by_label: + # A decision (troubleshoot / TS selection) preceded this job — connect from it. + source_node_id = last_decision_by_label.pop(label) + elif parent_job: + # Rerun or other child job — connect from the parent job node. + parent_key = f'{label}:{parent_job}' + source_node_id = job_nodes.get(parent_key, species_nodes.get(label, run_node_id)) + else: + # Normal first-launch job — connect from the species node. 
+ source_node_id = species_nodes.get(label, run_node_id) + graph.edge(source_node_id, job_node_id, label=reason) job_nodes[job_key] = job_node_id elif event_type == 'job_finished': @@ -163,16 +177,17 @@ def save_provenance_artifacts(project_directory: str, result_text += f"\n{', '.join(event['keywords'])}" graph.node(result_node_id, _wrap_graph_label(result_text), shape='note', fillcolor='cornsilk') graph.edge(job_nodes[job_key], result_node_id) - if label is not None: - last_node_by_label[label] = result_node_id - elif event_type in ['ts_guess_selected', 'job_troubleshooting']: + elif event_type in ('ts_guess_selected', 'ts_guess_selection_failed', 'job_troubleshooting'): decision_node_id = _sanitize_graphviz_id(f"decision_{event.get('event_id', 0)}") if event_type == 'ts_guess_selected': decision_text = f"Select TS guess {event.get('selected_index')}" if event.get('method'): decision_text += f"\n{event['method']}" fillcolor = 'lavender' + elif event_type == 'ts_guess_selection_failed': + decision_text = 'TS guess selection\nfailed' + fillcolor = 'mistyrose' else: decision_text = f"Troubleshoot {event.get('job_name', '')}" if event.get('methods'): @@ -180,13 +195,13 @@ def save_provenance_artifacts(project_directory: str, fillcolor = 'moccasin' graph.node(decision_node_id, _wrap_graph_label(decision_text), shape='diamond', fillcolor=fillcolor) source_job_key = event.get('job_key') - source_node_id = job_nodes.get(source_job_key) if source_job_key else last_node_by_label.get(label) + source_node_id = job_nodes.get(source_job_key) if source_job_key else species_nodes.get(label) if source_node_id is None and label is not None: source_node_id = species_nodes.get(label) if source_node_id is not None: graph.edge(source_node_id, decision_node_id) if label is not None: - last_node_by_label[label] = decision_node_id + last_decision_by_label[label] = decision_node_id elif event_type == 'species_initialized' and label in species_nodes: continue diff --git 
a/arc/plotter_test.py b/arc/plotter_test.py index 0a300f6ab6..20b07656d6 100644 --- a/arc/plotter_test.py +++ b/arc/plotter_test.py @@ -218,6 +218,21 @@ def test_save_irc_traj_animation(self): plotter.save_irc_traj_animation(irc_f_path, irc_r_path, out_path) self.assertTrue(os.path.isfile(out_path)) + def test_wrap_graph_label(self): + """Test that _wrap_graph_label preserves intentional newlines.""" + # Intentional newlines should be preserved, not collapsed. + result = plotter._wrap_graph_label("opt\nopt_a1\ngaussian\nwb97xd/def2tzvp", width=30) + lines = result.split('\n') + self.assertEqual(lines[0], 'opt') + self.assertEqual(lines[1], 'opt_a1') + self.assertEqual(lines[2], 'gaussian') + self.assertEqual(lines[3], 'wb97xd/def2tzvp') + # Long single lines should still be wrapped. + result = plotter._wrap_graph_label("this is a very long label that should be wrapped", width=20) + self.assertTrue(all(len(line) <= 20 for line in result.split('\n'))) + # Empty string returns empty. + self.assertEqual(plotter._wrap_graph_label(''), '') + def test_save_provenance_artifacts(self): """Test saving ARC provenance YAML / Graphviz artifacts.""" project = 'arc_project_for_testing_delete_after_usage' @@ -228,19 +243,37 @@ def test_save_provenance_artifacts(self): 'started_at': '2026-03-15T10:00:00', 'ended_at': '2026-03-15T10:05:00', 'events': [ - {'event_id': 1, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', 'label': 'spc1'}, - {'event_id': 2, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:00:01', + {'event_id': 1, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', + 'label': 'spc1'}, + {'event_id': 2, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', + 'label': 'TS0', 'is_ts': True}, + {'event_id': 3, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:00:01', 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'job_name': 'opt_a1', 'job_type': 'opt', 'job_adapter': 'gaussian', 'level': 
'b3lyp/6-31g(d)'}, - {'event_id': 3, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:00', - 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'job_name': 'opt_a1', 'job_type': 'opt', - 'status': 'done', 'run_time': '0:01:00'}, - {'event_id': 4, 'event_type': 'job_troubleshooting', 'timestamp': '2026-03-15T10:01:05', + {'event_id': 4, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:00', + 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'status': 'done', 'run_time': '0:01:00'}, + {'event_id': 5, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:01', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'job_name': 'freq_a2', 'job_type': 'freq', + 'job_adapter': 'gaussian', 'level': 'b3lyp/6-31g(d)'}, + {'event_id': 6, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:30', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'status': 'errored', + 'run_time': '0:00:30', 'keywords': ['memory']}, + {'event_id': 7, 'event_type': 'job_troubleshooting', 'timestamp': '2026-03-15T10:01:35', 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'job_name': 'freq_a2', 'job_type': 'freq', 'methods': ['memory']}, - {'event_id': 5, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:10', + {'event_id': 8, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:40', 'label': 'spc1', 'job_key': 'spc1:freq_a3', 'job_name': 'freq_a3', 'job_type': 'freq', - 'job_adapter': 'gaussian', 'provenance_reason': 'ess_troubleshoot'}, + 'job_adapter': 'gaussian', 'provenance_parent_job': 'freq_a2', + 'provenance_reason': 'ess_troubleshoot'}, + {'event_id': 9, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:02:00', + 'label': 'spc1', 'job_key': 'spc1:freq_a3', 'status': 'done', 'run_time': '0:00:20'}, + {'event_id': 10, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:02:01', + 'label': 'TS0', 'job_key': 'TS0:tsg0', 'job_name': 'tsg0', 'job_type': 'tsg', + 'job_adapter': 'autotst'}, + {'event_id': 11, 'event_type': 'job_finished', 'timestamp': 
'2026-03-15T10:03:00', + 'label': 'TS0', 'job_key': 'TS0:tsg0', 'status': 'done'}, + {'event_id': 12, 'event_type': 'ts_guess_selected', 'timestamp': '2026-03-15T10:03:01', + 'label': 'TS0', 'selected_index': 0, 'method': 'autotst', 'energy': -154.321}, ], } paths = plotter.save_provenance_artifacts(project_directory=project_directory, provenance=provenance) @@ -249,8 +282,23 @@ def test_save_provenance_artifacts(self): self.assertTrue(os.path.isfile(paths['dot'])) with open(paths['dot'], 'r') as f: dot = f.read() + # Species and job nodes are present. self.assertIn('spc1', dot) self.assertIn('opt_a1', dot) + self.assertIn('TS0', dot) + # Troubleshoot diamond and edge label rendered. + self.assertIn('Troubleshoot', dot) + self.assertIn('ess_troubleshoot', dot) + # TS guess selection diamond rendered. + self.assertIn('Select TS guess 0', dot) + self.assertIn('autotst', dot) + # Errored job node coloured correctly. + self.assertIn('mistyrose', dot) + # Normal jobs (opt_a1, freq_a2) connect from the species node, not from each other. + self.assertIn('species_spc1 -> job_spc1_opt_a1', dot) + self.assertIn('species_spc1 -> job_spc1_freq_a2', dot) + # Troubleshoot follow-up connects from the decision diamond, not the species node. + self.assertIn('decision_7 -> job_spc1_freq_a3', dot) @classmethod From 460fb81086ee466f419b975626f12f17a8d85bc9 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 17:22:48 +0300 Subject: [PATCH 06/14] Fix TS guess tracking and add scheduler unit tests - Use stable indices for TS guesses to ensure correct mapping between jobs and guess objects during conformer optimization. - Add unit tests for provenance deduplication, restart output sanitization, and multi-species label handling in the Scheduler. 
--- arc/scheduler.py | 21 +++++++--- arc/scheduler_test.py | 94 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 108 insertions(+), 7 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 536ecfdd61..4d137516fc 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -1258,14 +1258,18 @@ def run_ts_conformer_jobs(self, label: str): successful_tsgs = [tsg for tsg in self.species_dict[label].ts_guesses if tsg.success] if len(successful_tsgs) > 1: self.job_dict[label]['conf_opt'] = dict() - for i, tsg in enumerate(successful_tsgs): + for tsg in successful_tsgs: + if tsg.index is None: + existing_indices = [guess.index for guess in self.species_dict[label].ts_guesses + if guess.index is not None] + tsg.index = max(existing_indices or [-1]) + 1 self.run_job(label=label, xyz=tsg.initial_xyz, level_of_theory=self.ts_guess_level, job_type='conf_opt', - conformer=i, + conformer=tsg.index, ) - tsg.conformer_index = i # Store the conformer index in the TSGuess object to match them later. + tsg.conformer_index = tsg.index # Use a stable identifier for mapping back to TSGuess. 
elif len(successful_tsgs) == 1: if 'opt' not in self.job_dict[label].keys() and 'composite' not in self.job_dict[label].keys(): # proceed only if opt (/composite) not already spawned @@ -2051,9 +2055,14 @@ def parse_conformer(self, xyz = parser.parse_geometry(log_file_path=job.local_path_to_output_file) energy = parser.parse_e_elect(log_file_path=job.local_path_to_output_file) if self.species_dict[label].is_ts: - self.species_dict[label].ts_guesses[i].energy = energy - self.species_dict[label].ts_guesses[i].opt_xyz = xyz - self.species_dict[label].ts_guesses[i].index = i + tsg = next((guess for guess in self.species_dict[label].ts_guesses + if guess.conformer_index == i), None) + if tsg is None: + logger.warning(f'Could not find TSGuess for conformer {i} of {label} ' + f'(expected a matching conformer_index); skipping.') + return False + tsg.energy = energy + tsg.opt_xyz = xyz if energy is not None: logger.debug(f'Energy for TSGuess {i} of {label} is {energy:.2f}') else: diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index 77e8123092..de48aef11c 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -757,13 +757,105 @@ def test_add_label_to_unique_species_labels(self): self.assertEqual(unique_label, 'new_species_15_1') self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1']) + def test_initialize_provenance_dedup_on_restart(self): + """Test that _initialize_provenance does not re-emit species_initialized for species already in the log.""" + spc = ARCSpecies(label='ethanol', smiles='CCO') + project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov') + os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True) + # Write a fake provenance file that already has ethanol initialized. 
+ from arc.common import save_yaml_file + save_yaml_file(path=os.path.join(project_directory, 'output', 'provenance.yml'), + content={'version': 1, 'project': 'test', 'run_id': 'old_run', + 'started_at': '2026-01-01T00:00:00', + 'events': [{'event_id': 1, 'event_type': 'species_initialized', + 'label': 'ethanol', 'is_ts': False}]}) + sched = Scheduler(project='test_prov_dedup', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=project_directory, + testing=True, job_types=initialize_job_types()) + init_events = [e for e in sched.provenance['events'] + if e['event_type'] == 'species_initialized' and e.get('label') == 'ethanol'] + self.assertEqual(len(init_events), 1, 'species_initialized should not be duplicated on restart') + # New run should get its own run_id, not the old one. + self.assertNotEqual(sched.provenance['run_id'], 'old_run') + shutil.rmtree(project_directory, ignore_errors=True) + + def test_sanitize_restart_output(self): + """Test that _sanitize_restart_output resets convergence when paths are missing.""" + spc = ARCSpecies(label='H2O', smiles='O') + output = { + 'H2O': { + 'paths': {'geo': '', 'freq': '', 'sp': '', 'composite': ''}, + 'restart': '', 'convergence': True, + 'job_types': {'conf_opt': False, 'conf_sp': False, 'opt': True, 'freq': True, 'sp': True, + 'rotors': False, 'irc': False, 'fine': False, 'composite': False}, + } + } + sched = Scheduler(project='test_sanitize', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types(), + restart_dict={'output': output}) + 
self.assertFalse(sched.output['H2O']['convergence']) + for key in ['opt', 'freq', 'sp']: + self.assertFalse(sched.output['H2O']['job_types'][key]) + + def test_delete_all_species_jobs_resets_output(self): + """Test that delete_all_species_jobs clears convergence, job_types, and paths.""" + spc = ARCSpecies(label='CH4', smiles='C') + output = { + 'CH4': { + 'paths': {'geo': 'some/path.out', 'freq': 'freq.out', 'sp': 'sp.out', 'composite': ''}, + 'restart': '', 'convergence': True, + 'job_types': {'conf_opt': False, 'conf_sp': False, 'opt': True, 'freq': True, 'sp': True, + 'rotors': False, 'irc': False, 'fine': True, 'composite': False}, + } + } + sched = Scheduler(project='test_delete_jobs', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types(), + restart_dict={'output': output}) + sched.running_jobs['CH4'] = [] + sched.delete_all_species_jobs(label='CH4') + self.assertFalse(sched.output['CH4']['convergence']) + for key in ['opt', 'freq', 'sp', 'fine']: + self.assertFalse(sched.output['CH4']['job_types'][key]) + self.assertEqual(sched.output['CH4']['paths']['geo'], '') + + def test_provenance_multi_species_label(self): + """Test that provenance handles multi-species (list) labels by joining them.""" + spc1 = ARCSpecies(label='H2', smiles='[H][H]') + spc2 = ARCSpecies(label='O2', smiles='[O][O]') + sched = Scheduler(project='test_multi_label', ess_settings=self.ess_settings, + species_list=[spc1, spc2], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types()) + 
sched.record_provenance_event(event_type='test_event', label='H2+O2') + event = sched.provenance['events'][-1] + self.assertEqual(event['label'], 'H2+O2') + self.assertIsInstance(event['label'], str) + @classmethod def tearDownClass(cls): """ A function that is run ONCE after all unit tests in this class. Delete all project directories created during these unit tests """ - projects = ['arc_project_for_testing_delete_after_usage3', 'arc_project_for_testing_delete_after_usage6'] + projects = ['arc_project_for_testing_delete_after_usage3', 'arc_project_for_testing_delete_after_usage6', + 'arc_project_for_testing_delete_after_usage_prov'] for project in projects: project_directory = os.path.join(ARC_PATH, 'Projects', project) shutil.rmtree(project_directory, ignore_errors=True) From 4f0882f795a71ab7252225732cdea4b1b1b7e74a Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 17:23:06 +0300 Subject: [PATCH 07/14] Fix TS report typo and update test expectations - Correct "unsuccessfully" to "unsuccessful" in the transition state report string. - Update unit tests to reflect the deduplication of generation methods and the removal of trailing commas in the report output. 
--- arc/species/species.py | 2 +- arc/species/species_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arc/species/species.py b/arc/species/species.py index f5ae77a234..3a2bf32d1c 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -1539,7 +1539,7 @@ def make_ts_report(self): unique_successful_methods = list(dict.fromkeys(self.successful_methods)) self.ts_report += ','.join(unique_successful_methods) if self.unsuccessful_methods: - self.ts_report += '\nMethods that were unsuccessfully in generating a TS guess:\n' + self.ts_report += '\nMethods that were unsuccessful in generating a TS guess:\n' unique_unsuccessful_methods = list(dict.fromkeys(self.unsuccessful_methods)) self.ts_report += ','.join(unique_unsuccessful_methods) if not self.ts_guesses_exhausted: diff --git a/arc/species/species_test.py b/arc/species/species_test.py index 8074dd8c96..7f0fcd6ec2 100644 --- a/arc/species/species_test.py +++ b/arc/species/species_test.py @@ -1201,7 +1201,7 @@ def test_from_dict(self): 'ts_guesses_exhausted': False, 'ts_number': 0, 'ts_report': 'TS method summary for TS0 in C3_1 <=> C3_2:\n' 'Methods that successfully generated a TS guess:\n' - 'autotst,autotst,autotst,autotst,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,kinbot,kinbot,\n' + 'autotst,gcn,kinbot\n' 'The method that generated the best TS guess and its output used ' 'for the optimization: gcn\n', 'tsg_spawned': True, 'unsuccessful_methods': []} From 7f53dd875eb9808120bd7771f608294f2bf7a1f7 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 19:30:55 +0300 Subject: [PATCH 08/14] Updates --- arc/scheduler.py | 8 ++++++-- arc/scheduler_test.py | 45 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 4d137516fc..4f7f70f44c 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -528,7 +528,11 @@ def _initialize_provenance(self): logger.warning('Could not parse 
existing provenance.yml; starting a fresh provenance log.') provenance = None if isinstance(provenance, dict): - self.provenance['events'] = provenance.get('events', list()) + raw_events = provenance.get('events', list()) + if isinstance(raw_events, list) and all(isinstance(e, dict) for e in raw_events): + self.provenance['events'] = raw_events + else: + logger.warning('Existing provenance.yml has invalid events; starting with an empty event log.') already_initialized = {e['label'] for e in self.provenance['events'] if e.get('event_type') == 'species_initialized' and 'label' in e} for species in self.species_list: @@ -1263,13 +1267,13 @@ def run_ts_conformer_jobs(self, label: str): existing_indices = [guess.index for guess in self.species_dict[label].ts_guesses if guess.index is not None] tsg.index = max(existing_indices or [-1]) + 1 + tsg.conformer_index = tsg.index # Set before run_job so restart state is consistent. self.run_job(label=label, xyz=tsg.initial_xyz, level_of_theory=self.ts_guess_level, job_type='conf_opt', conformer=tsg.index, ) - tsg.conformer_index = tsg.index # Use a stable identifier for mapping back to TSGuess. 
elif len(successful_tsgs) == 1: if 'opt' not in self.job_dict[label].keys() and 'composite' not in self.job_dict[label].keys(): # proceed only if opt (/composite) not already spawned diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index de48aef11c..01fc947719 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -8,6 +8,7 @@ import unittest import os import shutil +from unittest import mock import arc.parser.parser as parser from arc.checks.ts import check_ts @@ -19,7 +20,7 @@ from arc.imports import settings from arc.reaction import ARCReaction from arc.species.converter import str_to_xyz -from arc.species.species import ARCSpecies +from arc.species.species import ARCSpecies, TSGuess default_levels_of_theory = settings['default_levels_of_theory'] @@ -832,6 +833,48 @@ def test_delete_all_species_jobs_resets_output(self): self.assertFalse(sched.output['CH4']['job_types'][key]) self.assertEqual(sched.output['CH4']['paths']['geo'], '') + def test_conformer_index_set_before_run_job(self): + """Test that tsg.conformer_index is assigned before run_job is called, so restart state is consistent.""" + ts_spc = ARCSpecies(label='TS0', is_ts=True, multiplicity=1, charge=0) + # Use geometries different enough to survive cluster_tsgs() deduplication. 
+ ts_spc.ts_guesses = [ + TSGuess(method='autotst', index=0, success=True, + xyz={'symbols': ('C', 'H', 'H', 'H', 'H'), 'isotopes': (12, 1, 1, 1, 1), + 'coords': ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (-1, 0, 0))}, + project_directory=self.project_directory), + TSGuess(method='gcn', index=1, success=True, + xyz={'symbols': ('C', 'H', 'H', 'H', 'H'), 'isotopes': (12, 1, 1, 1, 1), + 'coords': ((0, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2), (-2, 0, 0))}, + project_directory=self.project_directory), + ] + sched = Scheduler(project='test_conf_index_order', ess_settings=self.ess_settings, + species_list=[ts_spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types()) + # Track conformer_index values observed inside run_job. + observed = [] + + def capturing_run_job(**kwargs): + conformer = kwargs.get('conformer') + if conformer is not None: + tsg = next((g for g in ts_spc.ts_guesses if g.index == conformer), None) + observed.append((conformer, tsg.conformer_index if tsg else None)) + + with mock.patch.object(sched, 'run_job', side_effect=capturing_run_job), \ + mock.patch('arc.plotter.save_conformers_file'): + sched.run_ts_conformer_jobs(label='TS0') + + # Every call to run_job should have seen conformer_index already set. 
+ self.assertTrue(len(observed) >= 2, f'Expected at least 2 conf_opt jobs, got {len(observed)}') + for conformer_idx, conformer_index_value in observed: + self.assertIsNotNone(conformer_index_value, + f'conformer_index was None when run_job was called for conformer {conformer_idx}') + self.assertEqual(conformer_idx, conformer_index_value) + def test_provenance_multi_species_label(self): """Test that provenance handles multi-species (list) labels by joining them.""" spc1 = ARCSpecies(label='H2', smiles='[H][H]') From 987037f8eae324e54ee775ae2c0d411ca43cd25a Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 28 Mar 2026 19:48:52 +0300 Subject: [PATCH 09/14] Further updates --- arc/scheduler.py | 9 +++++++-- arc/scheduler_test.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 4f7f70f44c..3b5f73c427 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -377,6 +377,10 @@ def __init__(self, self.species_list.append(ts_species) self.species_dict[ts_species.label] = ts_species self.initialize_output_dict(ts_species.label) + self.record_provenance_event(event_type='species_initialized', + label=ts_species.label, + is_ts=True, + ) else: # The TS species was already loaded from a restart dict or an Arkane YAML file. 
ts_species = None @@ -534,7 +538,7 @@ def _initialize_provenance(self): else: logger.warning('Existing provenance.yml has invalid events; starting with an empty event log.') already_initialized = {e['label'] for e in self.provenance['events'] - if e.get('event_type') == 'species_initialized' and 'label' in e} + if e.get('event_type') == 'species_initialized' and isinstance(e.get('label'), str)} for species in self.species_list: if species.label not in already_initialized: self.record_provenance_event(event_type='species_initialized', @@ -548,7 +552,8 @@ def record_provenance_event(self, **data: Any, ): """Append a provenance event and persist the event log.""" - event = {'event_id': len(self.provenance['events']) + 1, + max_id = max((e.get('event_id', 0) for e in self.provenance['events']), default=0) + event = {'event_id': max_id + 1, 'event_type': event_type, 'timestamp': datetime.datetime.now().isoformat(timespec='seconds'), } diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index 01fc947719..fcb9c39e9b 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -875,6 +875,27 @@ def capturing_run_job(**kwargs): f'conformer_index was None when run_job was called for conformer {conformer_idx}') self.assertEqual(conformer_idx, conformer_index_value) + def test_provenance_records_ts_species_from_reactions(self): + """Test that TS species created from reactions get a species_initialized provenance event.""" + r_spc = ARCSpecies(label='nC3H7', smiles='[CH2]CC') + p_spc = ARCSpecies(label='iC3H7', smiles='C[CH]C') + rxn = ARCReaction(reactants=['nC3H7'], products=['iC3H7'], + r_species=[r_spc], p_species=[p_spc]) + rxn.index = 0 + sched = Scheduler(project='test_ts_prov', ess_settings=self.ess_settings, + species_list=[r_spc, p_spc], + rxn_list=[rxn], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + 
project_directory=self.project_directory, + testing=True, job_types=initialize_job_types()) + init_labels = [e['label'] for e in sched.provenance['events'] + if e.get('event_type') == 'species_initialized'] + self.assertIn('nC3H7', init_labels) + self.assertIn('iC3H7', init_labels) + self.assertIn('TS0', init_labels, 'TS species created from a reaction should get a species_initialized event') + def test_provenance_multi_species_label(self): """Test that provenance handles multi-species (list) labels by joining them.""" spc1 = ARCSpecies(label='H2', smiles='[H][H]') From 60a822695e5f0ee47da0f4fcf4524204341d003a Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sat, 11 Apr 2026 20:56:14 +0300 Subject: [PATCH 10/14] Update graph building --- arc/plotter.py | 137 +++++++ arc/plotter_test.py | 56 +++ arc/provenance/__init__.py | 38 ++ arc/provenance/graph.py | 366 +++++++++++++++++ arc/provenance/nodes.py | 386 ++++++++++++++++++ arc/provenance/provenance_test.py | 626 ++++++++++++++++++++++++++++++ arc/scheduler.py | 211 ++++++++-- arc/scheduler_test.py | 57 ++- arc/species/species.py | 8 + 9 files changed, 1853 insertions(+), 32 deletions(-) create mode 100644 arc/provenance/__init__.py create mode 100644 arc/provenance/graph.py create mode 100644 arc/provenance/nodes.py create mode 100644 arc/provenance/provenance_test.py diff --git a/arc/plotter.py b/arc/plotter.py index c84c66aff8..1d8df116e9 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -73,15 +73,135 @@ def _wrap_graph_label(text: str, width: int = 24) -> str: for line in (textwrap.wrap(part, width=width) or [''])) +def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz.Digraph': + """ + Render a :class:`ProvenanceGraph` as a Graphviz directed graph. 
+ + Node styling by type: + - **species**: box / aliceblue + - **calculation**: box / color by status (honeydew=done, mistyrose=errored, white=pending) + - **data**: note / cornsilk + - **decision**: diamond / color by kind (lavender, moccasin, mistyrose) + + Edge styling by type: + - ``selected_by``: solid green + - ``rejected_by``: dashed red + - ``troubleshot_by``: dashed orange + - ``retried_as`` / ``fine_of``: dotted gray + - others: solid black + + Args: + prov_graph: A :class:`ProvenanceGraph` instance. + run_label (str): Label for the root run node. + + Returns: + graphviz.Digraph: The rendered graph object. + """ + if graphviz is None: + raise ImportError('The graphviz Python package is required for render_provenance_graph(). ' + 'Install it with: conda install -c conda-forge python-graphviz') + gv = graphviz.Digraph( + name='arc_provenance', + comment=f'ARC provenance for {run_label}', + graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'}, + node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'}, + edge_attr={'fontname': 'Helvetica'}, + ) + + # Node styling lookup + _calc_colors = {'done': 'honeydew', 'errored': 'mistyrose', 'pending': 'white'} + _decision_colors = { + 'ts_guess_selection': 'lavender', + 'ts_guess_selection_failed': 'mistyrose', + 'job_troubleshooting': 'moccasin', + 'conformer_selection': 'lavender', + 'ts_guess_clustering': 'lavender', + 'ts_method_spawning': 'lavender', + 'ts_validation_freq': 'lightyellow', + 'ts_validation_nmd': 'lightyellow', + 'ts_validation_irc': 'lightyellow', + 'ts_switch': 'mistyrose', + } + + # Edge styling lookup + _edge_styles = { + 'selected_by': {'color': 'green3', 'style': 'solid'}, + 'rejected_by': {'color': 'red', 'style': 'dashed'}, + 'troubleshot_by': {'color': 'orange', 'style': 'dashed'}, + 'triggered_by': {'color': 'gray40', 'style': 'solid'}, + 'retried_as': {'color': 'gray60', 'style': 'dotted'}, + 'fine_of': {'color': 'gray60', 
'style': 'dotted'}, + 'spawned_by': {'color': 'blue', 'style': 'solid'}, + } + + for node in prov_graph.nodes.values(): + nid = _sanitize_graphviz_id(node.node_id) + ntype = node.node_type + + if ntype == 'species': + lbl = node.label or node.node_id + is_ts = (node.metadata or {}).get('is_ts', False) + if is_ts: + lbl += '\nTS' + gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor='aliceblue') + + elif ntype == 'calculation': + parts = [getattr(node, 'job_type', '') or '', getattr(node, 'job_name', '') or ''] + if getattr(node, 'job_adapter', None): + parts.append(node.job_adapter) + if getattr(node, 'level', None): + parts.append(node.level) + lbl = '\n'.join(p for p in parts if p) + status = getattr(node, 'status', 'pending') or 'pending' + fillcolor = _calc_colors.get(status, 'white') + gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor=fillcolor) + + elif ntype == 'data': + dk = getattr(node, 'data_kind', '') or '' + val = getattr(node, 'value', None) + lbl = dk + if val is not None and not isinstance(val, (list, dict)): + lbl += f'\n{val}' + gv.node(nid, _wrap_graph_label(lbl), shape='note', fillcolor='cornsilk') + + elif ntype == 'decision': + dk = getattr(node, 'decision_kind', '') or '' + outcome = getattr(node, 'outcome', '') or '' + lbl = dk.replace('_', ' ') + if outcome: + lbl += f'\n{outcome}' + fillcolor = _decision_colors.get(dk, 'lavender') + gv.node(nid, _wrap_graph_label(lbl, width=28), shape='diamond', fillcolor=fillcolor) + + else: + gv.node(nid, _wrap_graph_label(node.node_id)) + + for edge in prov_graph.edges: + src = _sanitize_graphviz_id(edge.source_id) + tgt = _sanitize_graphviz_id(edge.target_id) + etype = edge.edge_type + style_attrs = _edge_styles.get(etype, {}) + label = etype.replace('_', ' ') if etype not in ('belongs_to', 'input_of', 'output_of') else '' + gv.edge(src, tgt, label=label, **style_attrs) + + return gv + + def save_provenance_artifacts(project_directory: str, provenance: dict, + graph=None, ) -> dict: 
""" Save provenance YAML and render Graphviz artifacts for an ARC run. + When a ``graph`` (:class:`ProvenanceGraph`) is provided, the Graphviz + visualization is built from the graph's typed nodes and edges rather + than the flat event list, producing richer diagrams. + Args: project_directory (str): The ARC project directory. provenance (dict): A provenance dictionary with an ``events`` list. + graph: Optional ProvenanceGraph instance for graph-based rendering. Returns: dict: Paths to generated artifacts. @@ -99,6 +219,23 @@ def save_provenance_artifacts(project_directory: str, save_yaml_file(path=yml_path, content=provenance) return {'yml': yml_path, 'dot': None, 'svg': None} + # Prefer graph-based rendering when a ProvenanceGraph is available. + if graph is not None and len(graph) > 0: + gv_graph = render_provenance_graph(graph, run_label=run_label) + with open(dot_path, 'w') as f: + f.write(gv_graph.source) + try: + svg_data = gv_graph.pipe(format='svg') + except (graphviz.ExecutableNotFound, graphviz.CalledProcessError): + logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.') + else: + with open(svg_path, 'wb') as f: + f.write(svg_data) + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) + return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if os.path.isfile(svg_path) else None} + + # Fallback: event-based rendering (legacy path). 
graph = graphviz.Digraph( name='arc_provenance', comment=f'ARC provenance for {run_label}', diff --git a/arc/plotter_test.py b/arc/plotter_test.py index 20b07656d6..dea852b275 100644 --- a/arc/plotter_test.py +++ b/arc/plotter_test.py @@ -9,6 +9,11 @@ import shutil import unittest +try: + import graphviz +except ImportError: + graphviz = None + import arc.plotter as plotter from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, safe_copy_file from arc.species.converter import str_to_xyz @@ -300,6 +305,57 @@ def test_save_provenance_artifacts(self): # Troubleshoot follow-up connects from the decision diamond, not the species node. self.assertIn('decision_7 -> job_spc1_freq_a3', dot) + def test_render_provenance_graph(self): + """Test Graphviz rendering from a ProvenanceGraph object.""" + from arc.provenance import (ProvenanceGraph, DecisionKind, DataKind, EdgeType) + g = ProvenanceGraph(project='render_test') + sid = g.add_species_node(label='ethanol') + cid = g.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', job_adapter='gaussian', + level='b3lyp/6-31g(d)', status='done') + did = g.add_data_node(label='ethanol', data_kind=DataKind.energy, value=-79.5) + dec = g.add_decision_node(label='ethanol', + decision_kind=DecisionKind.conformer_selection, + outcome='Selected conformer #0') + g.add_edge(sid, cid, EdgeType.input_of) + g.add_edge(cid, did, EdgeType.output_of) + g.add_edge(did, dec, EdgeType.selected_by) + + if graphviz is not None: + gv = plotter.render_provenance_graph(g, run_label='render_test') + dot_source = gv.source + self.assertIn('ethanol', dot_source) + self.assertIn('opt', dot_source) + self.assertIn('energy', dot_source) + self.assertIn('conformer selection', dot_source) + self.assertIn('honeydew', dot_source) # done calc + self.assertIn('cornsilk', dot_source) # data node + self.assertIn('diamond', dot_source) # decision node + self.assertIn('green3', dot_source) # selected_by edge + + def 
test_save_provenance_artifacts_with_graph(self): + """Test that save_provenance_artifacts prefers graph-based rendering when a graph is provided.""" + from arc.provenance import (ProvenanceGraph, DecisionKind, EdgeType) + project = 'arc_project_for_testing_delete_after_usage' + project_directory = os.path.join(ARC_PATH, 'Projects', project) + g = ProvenanceGraph(project=project) + sid = g.add_species_node(label='spc1') + cid = g.add_calculation_node(label='spc1', job_name='opt_a1', + job_type='opt', status='done') + g.add_edge(sid, cid, EdgeType.input_of) + provenance = {'project': project, 'events': []} + paths = plotter.save_provenance_artifacts( + project_directory=project_directory, + provenance=provenance, + graph=g, + ) + self.assertTrue(os.path.isfile(paths['yml'])) + if paths['dot'] is not None: + with open(paths['dot'], 'r') as f: + dot = f.read() + # Graph-based rendering uses node IDs like species_1 not event-based species_spc1. + self.assertIn('species_1', dot) + self.assertIn('honeydew', dot) @classmethod def tearDownClass(cls): diff --git a/arc/provenance/__init__.py b/arc/provenance/__init__.py new file mode 100644 index 0000000000..d6da045f38 --- /dev/null +++ b/arc/provenance/__init__.py @@ -0,0 +1,38 @@ +""" +ARC provenance subpackage — directed acyclic graph for computational provenance. + +Tracks the full chain of inputs, calculations, decisions, and outputs that +produce ARC's results. Inspired by AiiDA's DAG model but adapted for ARC's +branching decision trees (TS guess evaluation, conformer selection, +troubleshooting loops). + +Submodules: + - ``nodes``: Node types, edge types, and their data classes. + - ``graph``: ProvenanceGraph container with query and serialization. 
+""" + +from arc.provenance.graph import ProvenanceGraph +from arc.provenance.nodes import ( + CalculationNode, + DataKind, + DataNode, + DecisionKind, + DecisionNode, + EdgeType, + NodeType, + ProvenanceEdge, + ProvenanceNode, +) + +__all__ = [ + 'ProvenanceGraph', + 'ProvenanceNode', + 'CalculationNode', + 'DataNode', + 'DecisionNode', + 'ProvenanceEdge', + 'NodeType', + 'DataKind', + 'DecisionKind', + 'EdgeType', +] diff --git a/arc/provenance/graph.py b/arc/provenance/graph.py new file mode 100644 index 0000000000..4ef0b7e33a --- /dev/null +++ b/arc/provenance/graph.py @@ -0,0 +1,366 @@ +""" +ProvenanceGraph — a directed acyclic graph for tracking ARC computational provenance. + +The graph stores typed nodes (species, calculations, data artifacts, decisions) +connected by typed directed edges (input_of, selected_by, troubleshot_by, etc.). +It supports forward/backward traversal, flexible queries, and YAML serialization +via the project's standard ``save_yaml_file`` / ``read_yaml_file`` helpers. +""" + +import datetime +import re +from collections import deque +from typing import Any, Dict, List, Optional + +from arc.common import get_logger, read_yaml_file, save_yaml_file +from arc.provenance.nodes import ( + CalculationNode, + DataNode, + DecisionNode, + NodeType, + ProvenanceEdge, + ProvenanceNode, + _enum_val, +) + +logger = get_logger() + +SCHEMA_VERSION = 2 + + +class ProvenanceGraph(object): + """ + A directed acyclic graph for tracking computational provenance. + + Args: + project (str, optional): The ARC project name. + run_id (str, optional): Unique run identifier. + + Attributes: + nodes (Dict[str, ProvenanceNode]): Maps node_id to node. + edges (List[ProvenanceEdge]): All directed edges. 
+ """ + + def __init__(self, + project: Optional[str] = None, + run_id: Optional[str] = None, + ): + self.project = project + self.run_id = run_id or ( + f'{project}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}' + if project else None + ) + self.nodes: Dict[str, ProvenanceNode] = {} + self.edges: List[ProvenanceEdge] = [] + self._counter: int = 0 + + # ── Node operations ────────────────────────────────────────────────────── + + def _next_id(self, prefix: str) -> str: + """Generate the next unique node ID with the given prefix.""" + self._counter += 1 + return f'{prefix}_{self._counter}' + + def add_node(self, node: ProvenanceNode) -> str: + """ + Add a node to the graph. + + Args: + node: The node to add. + + Returns: + str: The node's ID. + """ + if node.node_id in self.nodes: + logger.debug(f'Node {node.node_id!r} already exists in the provenance graph; skipping.') + return node.node_id + self.nodes[node.node_id] = node + return node.node_id + + def add_species_node(self, label: Optional[str] = None, is_ts: bool = False, + timestamp: Optional[str] = None) -> str: + """ + Convenience method to add a species node. + + Args: + label: Species label (optional). + is_ts: Whether this is a transition state. + timestamp: Optional ISO timestamp. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('species') + metadata = {'is_ts': is_ts} if is_ts else None + node = ProvenanceNode(node_id=node_id, node_type=NodeType.species, + label=label, timestamp=timestamp, metadata=metadata) + self.add_node(node) + return node_id + + def add_calculation_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a calculation node. + + Returns: + str: The new node's ID. 
+ """ + node_id = self._next_id('calc') + node = CalculationNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def add_data_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a data node. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('data') + node = DataNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def add_decision_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a decision node. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('decision') + node = DecisionNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def get_node(self, node_id: str) -> Optional[ProvenanceNode]: + """Return the node with the given ID, or None.""" + return self.nodes.get(node_id) + + def get_nodes_by_type(self, node_type: str, + label: Optional[str] = None) -> List[ProvenanceNode]: + """Return all nodes of the given type, optionally filtered by label.""" + results = [n for n in self.nodes.values() if n.node_type == _enum_val(node_type)] + if label is not None: + results = [n for n in results if n.label == label] + return results + + def get_nodes_by_label(self, label: str) -> List[ProvenanceNode]: + """Return all nodes associated with the given species label.""" + return [n for n in self.nodes.values() if n.label == label] + + def find_species_node(self, label: str) -> Optional[str]: + """Return the node_id of the species node for the given label, or None.""" + for n in self.nodes.values(): + if n.node_type == 'species' and n.label == label: + return n.node_id + return None + + def find_calc_node(self, label: str, job_name: str) -> Optional[str]: + """Return the node_id of a calculation node matching label and job_name, or None.""" + for n in self.nodes.values(): + if (n.node_type == 'calculation' + and n.label == label + and getattr(n, 'job_name', 
None) == job_name): + return n.node_id + return None + + def update_node(self, node_id: str, **attrs) -> bool: + """ + Update attributes on an existing node. + + Args: + node_id: The node to update. + **attrs: Attribute names and new values. + + Returns: + bool: True if the node was found and updated. + """ + node = self.nodes.get(node_id) + if node is None: + return False + for key, value in attrs.items(): + setattr(node, key, value) + return True + + # ── Edge operations ────────────────────────────────────────────────────── + + def add_edge(self, + source_id: str, + target_id: str, + edge_type: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> ProvenanceEdge: + """ + Add a directed edge between two nodes. + + Logs a warning if source or target node does not exist in the graph, + but still creates the edge (the node may be added later on restart). + + Args: + source_id: Source node ID. + target_id: Target node ID. + edge_type: One of :class:`EdgeType` values. + metadata: Optional extra data. + + Returns: + The created edge. 
+ """ + if source_id not in self.nodes: + logger.warning(f'Creating edge from non-existent source node {source_id!r}') + if target_id not in self.nodes: + logger.warning(f'Creating edge to non-existent target node {target_id!r}') + edge = ProvenanceEdge(source_id=source_id, target_id=target_id, + edge_type=edge_type, metadata=metadata) + self.edges.append(edge) + return edge + + def get_edges_from(self, node_id: str) -> List[ProvenanceEdge]: + """Return all edges originating from the given node.""" + return [e for e in self.edges if e.source_id == node_id] + + def get_edges_to(self, node_id: str) -> List[ProvenanceEdge]: + """Return all edges pointing to the given node.""" + return [e for e in self.edges if e.target_id == node_id] + + def get_edges_by_type(self, edge_type: str) -> List[ProvenanceEdge]: + """Return all edges of the given type.""" + return [e for e in self.edges if e.edge_type == _enum_val(edge_type)] + + # ── Traversal ──────────────────────────────────────────────────────────── + + def descendants(self, node_id: str) -> List[str]: + """ + Return all node IDs reachable forward from *node_id* (BFS). + + Does not include *node_id* itself. + """ + visited = set() + queue = deque() + for e in self.edges: + if e.source_id == node_id: + queue.append(e.target_id) + while queue: + nid = queue.popleft() + if nid in visited: + continue + visited.add(nid) + for e in self.edges: + if e.source_id == nid and e.target_id not in visited: + queue.append(e.target_id) + return list(visited) + + def ancestors(self, node_id: str) -> List[str]: + """ + Return all node IDs reachable backward from *node_id* (BFS). + + Does not include *node_id* itself. 
+ """ + visited = set() + queue = deque() + for e in self.edges: + if e.target_id == node_id: + queue.append(e.source_id) + while queue: + nid = queue.popleft() + if nid in visited: + continue + visited.add(nid) + for e in self.edges: + if e.target_id == nid and e.source_id not in visited: + queue.append(e.source_id) + return list(visited) + + # ── Query ──────────────────────────────────────────────────────────────── + + def query(self, + node_type: Optional[str] = None, + label: Optional[str] = None, + decision_kind: Optional[str] = None, + data_kind: Optional[str] = None, + status: Optional[str] = None, + ) -> List[ProvenanceNode]: + """ + Flexible query over nodes with optional filters. + + All provided filters are ANDed together. + + Args: + node_type: Filter by NodeType value. + label: Filter by species label. + decision_kind: Filter DecisionNodes by DecisionKind value. + data_kind: Filter DataNodes by DataKind value. + status: Filter CalculationNodes by job status. + + Returns: + List of matching nodes. 
+ """ + results = list(self.nodes.values()) + if node_type is not None: + results = [n for n in results if n.node_type == _enum_val(node_type)] + if label is not None: + results = [n for n in results if n.label == label] + if decision_kind is not None: + results = [n for n in results + if getattr(n, 'decision_kind', None) == _enum_val(decision_kind)] + if data_kind is not None: + results = [n for n in results + if getattr(n, 'data_kind', None) == _enum_val(data_kind)] + if status is not None: + results = [n for n in results + if getattr(n, 'status', None) == status] + return results + + # ── Serialization ──────────────────────────────────────────────────────── + + def as_dict(self) -> Dict[str, Any]: + """Serialize the full graph to a dict for YAML output.""" + d: Dict[str, Any] = { + 'schema_version': SCHEMA_VERSION, + } + if self.project is not None: + d['project'] = self.project + if self.run_id is not None: + d['run_id'] = self.run_id + d['nodes'] = [node.as_dict() for node in self.nodes.values()] + d['edges'] = [edge.as_dict() for edge in self.edges] + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceGraph': + """Reconstruct a ProvenanceGraph from a dict (e.g. loaded from YAML).""" + obj = object.__new__(cls) + obj.project = d.get('project') + obj.run_id = d.get('run_id') + obj.nodes = {} + obj.edges = [] + obj._counter = 0 + for node_dict in d.get('nodes', []): + node = ProvenanceNode.from_dict(node_dict) + obj.nodes[node.node_id] = node + # Update counter to avoid ID collisions on restart. 
+ match = re.search(r'_(\d+)$', node.node_id) + if match: + obj._counter = max(obj._counter, int(match.group(1))) + for edge_dict in d.get('edges', []): + obj.edges.append(ProvenanceEdge.from_dict(edge_dict)) + return obj + + def save(self, path: str) -> None: + """Persist the graph to a YAML file.""" + save_yaml_file(path=path, content=self.as_dict()) + + @classmethod + def load(cls, path: str) -> 'ProvenanceGraph': + """Load a ProvenanceGraph from a YAML file.""" + data = read_yaml_file(path) + if not isinstance(data, dict): + raise ValueError(f'Expected a dict in {path}, got {type(data).__name__}') + return cls.from_dict(data) + + def __len__(self) -> int: + return len(self.nodes) + + def __repr__(self) -> str: + return (f'ProvenanceGraph(project={self.project!r}, ' + f'nodes={len(self.nodes)}, edges={len(self.edges)})') diff --git a/arc/provenance/nodes.py b/arc/provenance/nodes.py new file mode 100644 index 0000000000..bebe43a8cc --- /dev/null +++ b/arc/provenance/nodes.py @@ -0,0 +1,386 @@ +""" +Provenance node and edge types for the ARC provenance DAG. + +Defines the fundamental building blocks of the provenance graph: + +- **Enums**: ``NodeType``, ``DataKind``, ``DecisionKind``, ``EdgeType`` + classify nodes and edges. +- **Node classes**: ``ProvenanceNode`` (base), ``CalculationNode``, + ``DataNode``, ``DecisionNode`` represent vertices in the DAG. +- **Edge class**: ``ProvenanceEdge`` represents a directed, typed + relationship between two nodes. + +All classes follow the ``as_dict()`` / ``from_dict()`` serialization +pattern used throughout ARC (see ``arc.job.pipe.pipe_state``). 
+""" + +import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + + +def _enum_val(val): + """Extract the plain string value from a str-Enum or pass through a string.""" + return val.value if isinstance(val, Enum) else val + + +# ── Enums ──────────────────────────────────────────────────────────────────── + + +class NodeType(str, Enum): + """Types of nodes in the provenance DAG.""" + species = 'species' + data = 'data' + calculation = 'calculation' + decision = 'decision' + + +class DataKind(str, Enum): + """Sub-classification for DataNode content.""" + geometry = 'geometry' + energy = 'energy' + frequencies = 'frequencies' + imaginary_freq = 'imaginary_freq' + irc_path = 'irc_path' + conformer_set = 'conformer_set' + ts_guess_set = 'ts_guess_set' + + +class DecisionKind(str, Enum): + """Sub-classification for DecisionNode decisions.""" + conformer_selection = 'conformer_selection' + ts_guess_clustering = 'ts_guess_clustering' + ts_guess_selection = 'ts_guess_selection' + ts_guess_selection_failed = 'ts_guess_selection_failed' + ts_validation_freq = 'ts_validation_freq' + ts_validation_nmd = 'ts_validation_nmd' + ts_validation_irc = 'ts_validation_irc' + ts_switch = 'ts_switch' + job_troubleshooting = 'job_troubleshooting' + ts_method_spawning = 'ts_method_spawning' + + +class EdgeType(str, Enum): + """Types of directed edges in the provenance DAG.""" + input_of = 'input_of' + output_of = 'output_of' + triggered_by = 'triggered_by' + selected_by = 'selected_by' + rejected_by = 'rejected_by' + spawned_by = 'spawned_by' + troubleshot_by = 'troubleshot_by' + belongs_to = 'belongs_to' + retried_as = 'retried_as' + fine_of = 'fine_of' + + +# ── Node classes ───────────────────────────────────────────────────────────── + + +class ProvenanceNode(object): + """ + Base class for a node in the provenance DAG. + + Args: + node_id (str): Unique identifier (e.g. ``'species_0'``, ``'calc_17'``). + node_type (str): One of :class:`NodeType` values. 
+ label (str, optional): Species label this node is associated with. + timestamp (str, optional): ISO 8601 creation timestamp. + Auto-generated if not provided. + metadata (dict, optional): Arbitrary extra key-value data. + """ + + def __init__(self, + node_id: str, + node_type: str, + label: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.node_id = node_id + self.node_type = _enum_val(node_type) + self.label = label + self.timestamp = timestamp or datetime.datetime.now().isoformat(timespec='seconds') + self.metadata = metadata + + def as_dict(self) -> Dict[str, Any]: + """Serialize to a sparse dict (None and empty values omitted).""" + d: Dict[str, Any] = { + 'node_id': self.node_id, + 'node_type': self.node_type, + } + if self.label is not None: + d['label'] = self.label + if self.timestamp is not None: + d['timestamp'] = self.timestamp + if self.metadata: + d['metadata'] = self.metadata + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceNode': + """Reconstruct a ProvenanceNode (or appropriate subclass) from a dict.""" + node_type = d.get('node_type', '') + # Dispatch to the correct subclass based on node_type. + # Keys use plain strings so YAML-deserialized values match. + subclass_map = { + 'calculation': CalculationNode, + 'data': DataNode, + 'decision': DecisionNode, + } + target_cls = subclass_map.get(node_type, cls) + if target_cls is not cls: + return target_cls.from_dict(d) + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', '') + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + return obj + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.node_id!r}, type={self.node_type!r}, label={self.label!r})' + + +class CalculationNode(ProvenanceNode): + """ + A computational job node (opt, freq, sp, scan, tsg, irc, composite, etc.). 
class CalculationNode(ProvenanceNode):
    """
    A computational job node (opt, freq, sp, scan, tsg, irc, composite, etc.).

    Args:
        node_id (str): Unique identifier.
        label (str, optional): Species label.
        job_name (str, optional): ARC job name (e.g. ``'opt_a1'``).
        job_type (str, optional): Job type (e.g. ``'opt'``, ``'freq'``).
        job_adapter (str, optional): ESS adapter (e.g. ``'gaussian'``).
        level (str, optional): Level of theory string.
        status (str, optional): Job outcome: ``'pending'``, ``'done'``, ``'errored'``.
        run_time (str, optional): Wall-clock duration string.
        conformer (int, optional): Conformer index, if applicable.
        tsg (int, optional): TS guess index, if applicable.
        ess_trsh_methods (list, optional): Troubleshooting methods applied.
        timestamp (str, optional): ISO 8601 creation timestamp.
        metadata (dict, optional): Extra data.
    """

    def __init__(self,
                 node_id: str,
                 label: Optional[str] = None,
                 job_name: Optional[str] = None,
                 job_type: Optional[str] = None,
                 job_adapter: Optional[str] = None,
                 level: Optional[str] = None,
                 status: Optional[str] = None,
                 run_time: Optional[str] = None,
                 conformer: Optional[int] = None,
                 tsg: Optional[int] = None,
                 ess_trsh_methods: Optional[List[str]] = None,
                 timestamp: Optional[str] = None,
                 metadata: Optional[Dict[str, Any]] = None,
                 ):
        super().__init__(node_id=node_id, node_type=NodeType.calculation,
                         label=label, timestamp=timestamp, metadata=metadata)
        self.job_name = job_name
        self.job_type = job_type
        self.job_adapter = job_adapter
        self.level = level
        self.status = status
        self.run_time = run_time
        self.conformer = conformer
        self.tsg = tsg
        self.ess_trsh_methods = ess_trsh_methods

    def as_dict(self) -> Dict[str, Any]:
        """Serialize to a sparse dict; unset (None) fields are omitted."""
        payload = super().as_dict()
        # Scalar fields, serialized in a fixed order when present.
        for field in ('job_name', 'job_type', 'job_adapter', 'level',
                      'status', 'run_time', 'conformer', 'tsg'):
            value = getattr(self, field)
            if value is not None:
                payload[field] = value
        if self.ess_trsh_methods:
            # Copy so later mutation of the node does not alter the serialized dict.
            payload['ess_trsh_methods'] = list(self.ess_trsh_methods)
        return payload

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> 'CalculationNode':
        """Reconstruct a CalculationNode from a dict (missing fields become None)."""
        obj = object.__new__(cls)
        obj.node_id = d['node_id']
        obj.node_type = d.get('node_type', NodeType.calculation)
        for field in ('label', 'timestamp', 'metadata', 'job_name', 'job_type',
                      'job_adapter', 'level', 'status', 'run_time',
                      'conformer', 'tsg', 'ess_trsh_methods'):
            setattr(obj, field, d.get(field))
        return obj
+ """ + + def __init__(self, + node_id: str, + label: Optional[str] = None, + data_kind: Optional[str] = None, + value: Any = None, + source_path: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__(node_id=node_id, node_type=NodeType.data, + label=label, timestamp=timestamp, metadata=metadata) + self.data_kind = _enum_val(data_kind) if data_kind is not None else None + self.value = value + self.source_path = source_path + + def as_dict(self) -> Dict[str, Any]: + d = super().as_dict() + if self.data_kind is not None: + d['data_kind'] = self.data_kind + if self.value is not None: + d['value'] = self.value + if self.source_path is not None: + d['source_path'] = self.source_path + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'DataNode': + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', NodeType.data) + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + obj.data_kind = d.get('data_kind') + obj.value = d.get('value') + obj.source_path = d.get('source_path') + return obj + + +class DecisionNode(ProvenanceNode): + """ + An algorithmic decision point (conformer selection, TS validation, etc.). + + Args: + node_id (str): Unique identifier. + label (str, optional): Species label. + decision_kind (str, optional): One of :class:`DecisionKind` values. + criteria (dict, optional): The selection/rejection criteria applied. + outcome (str, optional): Human-readable summary of the decision result. + timestamp (str, optional): ISO 8601 creation timestamp. + metadata (dict, optional): Extra data. 
+ """ + + def __init__(self, + node_id: str, + label: Optional[str] = None, + decision_kind: Optional[str] = None, + criteria: Optional[Dict[str, Any]] = None, + outcome: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__(node_id=node_id, node_type=NodeType.decision, + label=label, timestamp=timestamp, metadata=metadata) + self.decision_kind = _enum_val(decision_kind) if decision_kind is not None else None + self.criteria = criteria + self.outcome = outcome + + def as_dict(self) -> Dict[str, Any]: + d = super().as_dict() + if self.decision_kind is not None: + d['decision_kind'] = self.decision_kind + if self.criteria: + d['criteria'] = self.criteria + if self.outcome is not None: + d['outcome'] = self.outcome + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'DecisionNode': + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', NodeType.decision) + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + obj.decision_kind = d.get('decision_kind') + obj.criteria = d.get('criteria') + obj.outcome = d.get('outcome') + return obj + + +# ── Edge class ─────────────────────────────────────────────────────────────── + + +class ProvenanceEdge(object): + """ + A typed directed edge in the provenance DAG. + + Args: + source_id (str): Node ID of the edge source. + target_id (str): Node ID of the edge target. + edge_type (str): One of :class:`EdgeType` values. + metadata (dict, optional): Arbitrary extra key-value data. 
+ """ + + def __init__(self, + source_id: str, + target_id: str, + edge_type: str, + metadata: Optional[Dict[str, Any]] = None, + ): + self.source_id = source_id + self.target_id = target_id + self.edge_type = _enum_val(edge_type) + self.metadata = metadata + + def as_dict(self) -> Dict[str, Any]: + d: Dict[str, Any] = { + 'source_id': self.source_id, + 'target_id': self.target_id, + 'edge_type': self.edge_type, + } + if self.metadata: + d['metadata'] = self.metadata + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceEdge': + obj = object.__new__(cls) + obj.source_id = d['source_id'] + obj.target_id = d['target_id'] + obj.edge_type = d.get('edge_type', '') + obj.metadata = d.get('metadata') + return obj + + def __repr__(self) -> str: + return f'ProvenanceEdge({self.source_id!r} --{self.edge_type}--> {self.target_id!r})' diff --git a/arc/provenance/provenance_test.py b/arc/provenance/provenance_test.py new file mode 100644 index 0000000000..0db5aecef3 --- /dev/null +++ b/arc/provenance/provenance_test.py @@ -0,0 +1,626 @@ +"""Tests for the arc.provenance package — nodes, edges, and ProvenanceGraph.""" + +import os +import shutil +import tempfile +import unittest + +from arc.provenance.graph import SCHEMA_VERSION, ProvenanceGraph +from arc.provenance.nodes import ( + CalculationNode, + DataKind, + DataNode, + DecisionKind, + DecisionNode, + EdgeType, + NodeType, + ProvenanceEdge, + ProvenanceNode, +) + + +class TestEnums(unittest.TestCase): + """Verify that enums are str-based and contain expected values.""" + + def test_node_type_is_str(self): + self.assertIsInstance(NodeType.species, str) + self.assertEqual(NodeType.calculation, 'calculation') + + def test_data_kind_values(self): + self.assertIn('geometry', [dk.value for dk in DataKind]) + self.assertIn('energy', [dk.value for dk in DataKind]) + + def test_decision_kind_values(self): + expected = {'conformer_selection', 'ts_guess_clustering', 'ts_guess_selection', + 
class TestEnums(unittest.TestCase):
    """Verify that enums are str-based and contain expected values."""

    def test_node_type_is_str(self):
        self.assertIsInstance(NodeType.species, str)
        self.assertEqual(NodeType.calculation, 'calculation')

    def test_data_kind_values(self):
        kinds = [dk.value for dk in DataKind]
        self.assertIn('geometry', kinds)
        self.assertIn('energy', kinds)

    def test_decision_kind_values(self):
        expected = {'conformer_selection', 'ts_guess_clustering', 'ts_guess_selection',
                    'ts_guess_selection_failed', 'ts_validation_freq', 'ts_validation_nmd',
                    'ts_validation_irc', 'ts_switch', 'job_troubleshooting', 'ts_method_spawning'}
        self.assertEqual(expected, {dk.value for dk in DecisionKind})

    def test_edge_type_values(self):
        edge_types = [et.value for et in EdgeType]
        for expected in ('input_of', 'selected_by', 'rejected_by'):
            self.assertIn(expected, edge_types)


class TestProvenanceNode(unittest.TestCase):
    """Test the base ProvenanceNode class."""

    def test_creation(self):
        n = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='ethanol')
        self.assertEqual(n.node_id, 'species_1')
        self.assertEqual(n.node_type, 'species')
        self.assertEqual(n.label, 'ethanol')
        self.assertIsNotNone(n.timestamp)

    def test_as_dict_sparse(self):
        d = ProvenanceNode(node_id='species_1', node_type=NodeType.species).as_dict()
        self.assertIn('node_id', d)
        self.assertIn('node_type', d)
        self.assertNotIn('label', d)
        self.assertNotIn('metadata', d)

    def test_as_dict_with_metadata(self):
        n = ProvenanceNode(node_id='species_1', node_type=NodeType.species,
                           label='H2O', metadata={'is_ts': True})
        self.assertEqual(n.as_dict()['metadata'], {'is_ts': True})

    def test_from_dict_roundtrip(self):
        original = ProvenanceNode(node_id='species_1', node_type=NodeType.species,
                                  label='ethanol', metadata={'is_ts': False})
        restored = ProvenanceNode.from_dict(original.as_dict())
        self.assertEqual(restored.node_id, 'species_1')
        self.assertEqual(restored.node_type, 'species')
        self.assertEqual(restored.label, 'ethanol')

    def test_from_dict_dispatches_to_subclass(self):
        restored = ProvenanceNode.from_dict({'node_id': 'calc_1',
                                             'node_type': 'calculation',
                                             'job_name': 'opt_a1'})
        self.assertIsInstance(restored, CalculationNode)
        self.assertEqual(restored.job_name, 'opt_a1')

    def test_repr(self):
        n = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='ethanol')
        self.assertIn('species_1', repr(n))


class TestCalculationNode(unittest.TestCase):
    """Test CalculationNode creation and serialization."""

    def test_creation(self):
        node = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1',
                               job_type='opt', job_adapter='gaussian',
                               level='wb97xd/def2-tzvp', status='done')
        self.assertEqual(node.node_type, 'calculation')
        self.assertEqual(node.job_name, 'opt_a1')
        self.assertEqual(node.status, 'done')

    def test_as_dict_sparse(self):
        d = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1').as_dict()
        self.assertIn('job_name', d)
        self.assertNotIn('job_adapter', d)
        self.assertNotIn('ess_trsh_methods', d)

    def test_from_dict_roundtrip(self):
        node = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1',
                               job_type='opt', status='errored',
                               ess_trsh_methods=['SCF=QC', 'int=grid=ultrafine'])
        restored = CalculationNode.from_dict(node.as_dict())
        self.assertEqual(restored.job_name, 'opt_a1')
        self.assertEqual(restored.status, 'errored')
        self.assertEqual(restored.ess_trsh_methods, ['SCF=QC', 'int=grid=ultrafine'])
        self.assertIsNone(restored.conformer)


class TestDataNode(unittest.TestCase):
    """Test DataNode creation and serialization."""

    def test_creation(self):
        node = DataNode(node_id='data_1', label='ethanol',
                        data_kind=DataKind.energy, value=-79.123456)
        self.assertEqual(node.node_type, 'data')
        self.assertEqual(node.data_kind, 'energy')
        self.assertEqual(node.value, -79.123456)

    def test_from_dict_roundtrip(self):
        node = DataNode(node_id='data_1', label='ethanol',
                        data_kind=DataKind.frequencies, value=[3200.5, 1500.3, 800.1])
        restored = DataNode.from_dict(node.as_dict())
        self.assertEqual(restored.data_kind, 'frequencies')
        self.assertEqual(restored.value, [3200.5, 1500.3, 800.1])
TestDecisionNode(unittest.TestCase): + """Test DecisionNode creation and serialization.""" + + def test_creation(self): + node = DecisionNode(node_id='decision_1', label='TS0', + decision_kind=DecisionKind.ts_guess_selection, + outcome='Selected TSGuess #3 (energy=-150.2 kJ/mol)') + self.assertEqual(node.node_type, 'decision') + self.assertEqual(node.decision_kind, 'ts_guess_selection') + self.assertIn('TSGuess #3', node.outcome) + + def test_from_dict_roundtrip(self): + node = DecisionNode(node_id='decision_1', label='TS0', + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': ['SCF', 'Memory'], + 'applied': 'SCF=QC'}, + outcome='Retrying with SCF=QC') + d = node.as_dict() + restored = DecisionNode.from_dict(d) + self.assertEqual(restored.decision_kind, 'job_troubleshooting') + self.assertEqual(restored.criteria['error_keywords'], ['SCF', 'Memory']) + self.assertEqual(restored.outcome, 'Retrying with SCF=QC') + + +class TestProvenanceEdge(unittest.TestCase): + """Test ProvenanceEdge creation and serialization.""" + + def test_creation(self): + edge = ProvenanceEdge(source_id='species_1', target_id='calc_1', + edge_type=EdgeType.input_of) + self.assertEqual(edge.source_id, 'species_1') + self.assertEqual(edge.edge_type, 'input_of') + + def test_as_dict_sparse(self): + edge = ProvenanceEdge(source_id='a', target_id='b', edge_type=EdgeType.output_of) + d = edge.as_dict() + self.assertNotIn('metadata', d) + + def test_from_dict_roundtrip(self): + edge = ProvenanceEdge(source_id='calc_1', target_id='data_1', + edge_type=EdgeType.output_of, + metadata={'reason': 'rerun'}) + d = edge.as_dict() + restored = ProvenanceEdge.from_dict(d) + self.assertEqual(restored.source_id, 'calc_1') + self.assertEqual(restored.metadata, {'reason': 'rerun'}) + + def test_repr(self): + edge = ProvenanceEdge(source_id='a', target_id='b', edge_type=EdgeType.selected_by) + self.assertIn('selected_by', repr(edge)) + + +class TestProvenanceGraph(unittest.TestCase): + 
"""Test ProvenanceGraph CRUD, traversal, query, and serialization.""" + + def setUp(self): + self.graph = ProvenanceGraph(project='test_project') + + def test_add_species_node(self): + nid = self.graph.add_species_node(label='ethanol') + self.assertIn(nid, self.graph.nodes) + self.assertEqual(self.graph.nodes[nid].node_type, 'species') + self.assertEqual(self.graph.nodes[nid].label, 'ethanol') + + def test_add_calculation_node(self): + nid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='pending') + node = self.graph.get_node(nid) + self.assertIsInstance(node, CalculationNode) + self.assertEqual(node.job_name, 'opt_a1') + + def test_add_data_node(self): + nid = self.graph.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.5) + node = self.graph.get_node(nid) + self.assertIsInstance(node, DataNode) + self.assertEqual(node.value, -79.5) + + def test_add_decision_node(self): + nid = self.graph.add_decision_node(label='TS0', + decision_kind=DecisionKind.ts_guess_selection, + outcome='Selected TSG #2') + node = self.graph.get_node(nid) + self.assertIsInstance(node, DecisionNode) + self.assertEqual(node.outcome, 'Selected TSG #2') + + def test_node_id_auto_increment(self): + id1 = self.graph.add_species_node(label='A') + id2 = self.graph.add_species_node(label='B') + id3 = self.graph.add_calculation_node(label='A', job_name='opt_a1') + self.assertEqual(id1, 'species_1') + self.assertEqual(id2, 'species_2') + self.assertEqual(id3, 'calc_3') + + def test_duplicate_node_skipped(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='X') + self.graph.add_node(node) + self.graph.add_node(node) + self.assertEqual(len(self.graph.nodes), 1) + + def test_add_edge(self): + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1') + edge = self.graph.add_edge(sid, cid, EdgeType.input_of) + 
self.assertEqual(len(self.graph.edges), 1) + self.assertEqual(edge.edge_type, 'input_of') + + def test_get_edges_from_and_to(self): + sid = self.graph.add_species_node(label='A') + c1 = self.graph.add_calculation_node(label='A', job_name='opt_a1') + c2 = self.graph.add_calculation_node(label='A', job_name='freq_a2') + self.graph.add_edge(sid, c1, EdgeType.input_of) + self.graph.add_edge(sid, c2, EdgeType.input_of) + self.assertEqual(len(self.graph.get_edges_from(sid)), 2) + self.assertEqual(len(self.graph.get_edges_to(c1)), 1) + + def test_get_nodes_by_type(self): + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + species_nodes = self.graph.get_nodes_by_type(NodeType.species) + self.assertEqual(len(species_nodes), 2) + calc_nodes = self.graph.get_nodes_by_type(NodeType.calculation) + self.assertEqual(len(calc_nodes), 1) + + def test_get_nodes_by_type_with_label_filter(self): + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + self.graph.add_calculation_node(label='B', job_name='opt') + a_calcs = self.graph.get_nodes_by_type(NodeType.calculation, label='A') + self.assertEqual(len(a_calcs), 1) + + def test_get_nodes_by_label(self): + self.graph.add_species_node(label='ethanol') + self.graph.add_calculation_node(label='ethanol', job_name='opt') + self.graph.add_calculation_node(label='methane', job_name='opt') + eth_nodes = self.graph.get_nodes_by_label('ethanol') + self.assertEqual(len(eth_nodes), 2) + + def test_find_species_node(self): + sid = self.graph.add_species_node(label='ethanol') + self.assertEqual(self.graph.find_species_node('ethanol'), sid) + self.assertIsNone(self.graph.find_species_node('missing')) + + def test_find_calc_node(self): + self.graph.add_calculation_node(label='A', job_name='opt_a1') + cid = self.graph.find_calc_node('A', 'opt_a1') + 
self.assertIsNotNone(cid) + self.assertIsNone(self.graph.find_calc_node('A', 'missing')) + + def test_update_node(self): + cid = self.graph.add_calculation_node(label='A', job_name='opt', status='pending') + self.assertTrue(self.graph.update_node(cid, status='done', run_time='00:05:30')) + node = self.graph.get_node(cid) + self.assertEqual(node.status, 'done') + self.assertEqual(node.run_time, '00:05:30') + + def test_update_node_missing(self): + self.assertFalse(self.graph.update_node('nonexistent', status='done')) + + def test_get_edges_by_type(self): + sid = self.graph.add_species_node(label='A') + c1 = self.graph.add_calculation_node(label='A', job_name='opt') + d1 = self.graph.add_data_node(label='A', data_kind=DataKind.energy) + self.graph.add_edge(sid, c1, EdgeType.input_of) + self.graph.add_edge(c1, d1, EdgeType.output_of) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.input_of)), 1) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.output_of)), 1) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.selected_by)), 0) + + # ── Traversal ──────────────────────────────────────────────────────────── + + def test_descendants(self): + """species -> calc -> data -> decision""" + sid = self.graph.add_species_node(label='A') + cid = self.graph.add_calculation_node(label='A', job_name='opt') + did = self.graph.add_data_node(label='A', data_kind=DataKind.geometry) + dec = self.graph.add_decision_node(label='A', decision_kind=DecisionKind.conformer_selection) + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + self.graph.add_edge(did, dec, EdgeType.selected_by) + desc = self.graph.descendants(sid) + self.assertEqual(set(desc), {cid, did, dec}) + self.assertNotIn(sid, desc) + + def test_ancestors(self): + """Reverse traversal.""" + sid = self.graph.add_species_node(label='A') + cid = self.graph.add_calculation_node(label='A', job_name='opt') + did = self.graph.add_data_node(label='A', 
data_kind=DataKind.energy) + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + anc = self.graph.ancestors(did) + self.assertEqual(set(anc), {sid, cid}) + + def test_no_descendants(self): + sid = self.graph.add_species_node(label='A') + self.assertEqual(self.graph.descendants(sid), []) + + # ── Query ──────────────────────────────────────────────────────────────── + + def test_query_by_node_type(self): + self.graph.add_species_node(label='A') + self.graph.add_calculation_node(label='A', job_name='opt') + results = self.graph.query(node_type=NodeType.species) + self.assertEqual(len(results), 1) + + def test_query_by_decision_kind(self): + self.graph.add_decision_node(label='A', decision_kind=DecisionKind.ts_guess_selection) + self.graph.add_decision_node(label='A', decision_kind=DecisionKind.job_troubleshooting) + results = self.graph.query(decision_kind=DecisionKind.ts_guess_selection) + self.assertEqual(len(results), 1) + + def test_query_by_status(self): + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + self.graph.add_calculation_node(label='A', job_name='freq', status='errored') + done = self.graph.query(status='done') + self.assertEqual(len(done), 1) + self.assertEqual(done[0].job_name, 'opt') + + def test_query_combined_filters(self): + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + self.graph.add_calculation_node(label='B', job_name='opt', status='done') + results = self.graph.query(node_type=NodeType.calculation, label='A', status='done') + self.assertEqual(len(results), 1) + self.assertEqual(results[0].label, 'A') + + # ── Serialization ──────────────────────────────────────────────────────── + + def test_as_dict_structure(self): + self.graph.add_species_node(label='A') + d = self.graph.as_dict() + self.assertEqual(d['schema_version'], SCHEMA_VERSION) + self.assertEqual(d['project'], 'test_project') + self.assertIsInstance(d['nodes'], list) + 
self.assertIsInstance(d['edges'], list) + + def test_from_dict_roundtrip(self): + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + status='done') + self.graph.add_edge(sid, cid, EdgeType.input_of) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + self.assertEqual(len(restored.nodes), 2) + self.assertEqual(len(restored.edges), 1) + self.assertEqual(restored.project, 'test_project') + self.assertIsInstance(restored.get_node(cid), CalculationNode) + self.assertEqual(restored.get_node(cid).status, 'done') + + def test_restart_continues_counter(self): + """After loading a graph, new node IDs should not collide with existing ones.""" + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + new_id = restored.add_species_node(label='C') + # _counter should be at least 3 (from species_1, species_2, calc_3), + # so next ID should be species_4 or higher + self.assertNotIn(new_id, ['species_1', 'species_2', 'calc_3']) + + def test_save_and_load(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmp_dir) + path = os.path.join(tmp_dir, 'provenance_graph.yml') + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='done') + did = self.graph.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.5) + dec = self.graph.add_decision_node(label='ethanol', + decision_kind=DecisionKind.conformer_selection, + outcome='Selected conformer #0') + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + self.graph.add_edge(did, dec, EdgeType.selected_by) + self.graph.save(path) + self.assertTrue(os.path.isfile(path)) + loaded = ProvenanceGraph.load(path) + 
self.assertEqual(len(loaded.nodes), 4) + self.assertEqual(len(loaded.edges), 3) + self.assertIsInstance(loaded.get_node(cid), CalculationNode) + self.assertIsInstance(loaded.get_node(did), DataNode) + self.assertIsInstance(loaded.get_node(dec), DecisionNode) + + def test_len_and_repr(self): + self.assertEqual(len(self.graph), 0) + self.graph.add_species_node(label='A') + self.assertEqual(len(self.graph), 1) + self.assertIn('test_project', repr(self.graph)) + + +class TestProvenanceGraphWorkflow(unittest.TestCase): + """ + Integration-style test: build a realistic provenance graph for a species + going through opt → freq → sp, with a troubleshooting retry on freq. + """ + + def test_realistic_workflow(self): + g = ProvenanceGraph(project='workflow_test') + + # Species initialized + sid = g.add_species_node(label='ethanol') + + # Opt job succeeds + opt_id = g.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='done') + g.add_edge(sid, opt_id, EdgeType.input_of) + opt_geo = g.add_data_node(label='ethanol', data_kind=DataKind.geometry, + source_path='calcs/opt_a1/output.log') + g.add_edge(opt_id, opt_geo, EdgeType.output_of) + + # Freq job fails + freq1_id = g.add_calculation_node(label='ethanol', job_name='freq_a2', + job_type='freq', status='errored') + g.add_edge(opt_geo, freq1_id, EdgeType.input_of) + + # Troubleshooting decision + trsh_id = g.add_decision_node(label='ethanol', + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': ['SCF']}, + outcome='Retrying with SCF=QC') + g.add_edge(freq1_id, trsh_id, EdgeType.troubleshot_by) + + # Freq job retried and succeeds + freq2_id = g.add_calculation_node(label='ethanol', job_name='freq_a3', + job_type='freq', status='done', + ess_trsh_methods=['SCF=QC']) + g.add_edge(trsh_id, freq2_id, EdgeType.spawned_by) + g.add_edge(freq1_id, freq2_id, EdgeType.retried_as) + freq_data = g.add_data_node(label='ethanol', data_kind=DataKind.frequencies, + value=[3200.5, 1500.3]) 
+ g.add_edge(freq2_id, freq_data, EdgeType.output_of) + + # SP job succeeds + sp_id = g.add_calculation_node(label='ethanol', job_name='sp_a4', + job_type='sp', status='done') + g.add_edge(opt_geo, sp_id, EdgeType.input_of) + sp_energy = g.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.123456) + g.add_edge(sp_id, sp_energy, EdgeType.output_of) + + # Verify graph structure + self.assertEqual(len(g.nodes), 9) + self.assertEqual(len(g.edges), 9) + + # Verify traversal: ancestors of the final energy should trace back to species + anc = g.ancestors(sp_energy) + self.assertIn(sid, anc) + self.assertIn(opt_id, anc) + self.assertIn(sp_id, anc) + + # Verify query: find all troubleshooting decisions + trsh_decisions = g.query(decision_kind=DecisionKind.job_troubleshooting) + self.assertEqual(len(trsh_decisions), 1) + self.assertEqual(trsh_decisions[0].criteria['error_keywords'], ['SCF']) + + # Verify query: find all errored calculations + errored = g.query(node_type=NodeType.calculation, status='errored') + self.assertEqual(len(errored), 1) + self.assertEqual(errored[0].job_name, 'freq_a2') + + # Verify traversal: descendants of the troubleshooting decision + # should include the retried freq job and its output + desc = g.descendants(trsh_id) + self.assertIn(freq2_id, desc) + self.assertIn(freq_data, desc) + + +class TestEdgeCases(unittest.TestCase): + """Tests for edge cases identified during code review.""" + + def setUp(self): + self.graph = ProvenanceGraph(project='edge_case_test') + + def test_add_edge_warns_on_nonexistent_nodes(self): + """add_edge should still work but log warnings for missing nodes.""" + sid = self.graph.add_species_node(label='A') + edge = self.graph.add_edge(sid, 'nonexistent_target', EdgeType.input_of) + self.assertEqual(len(self.graph.edges), 1) + self.assertEqual(edge.target_id, 'nonexistent_target') + + def test_roundtrip_preserves_zero_value(self): + """DataNode with value=0 (falsy) must survive serialization.""" + nid = 
self.graph.add_data_node(label='A', data_kind=DataKind.energy, value=0) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + node = restored.get_node(nid) + self.assertIsInstance(node, DataNode) + self.assertEqual(node.value, 0) + + def test_roundtrip_preserves_false_in_metadata(self): + """Metadata with False values must survive serialization.""" + node = ProvenanceNode(node_id='species_99', node_type=NodeType.species, + label='X', metadata={'is_ts': False, 'converged': False}) + self.graph.add_node(node) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + restored_node = restored.get_node('species_99') + self.assertEqual(restored_node.metadata['is_ts'], False) + self.assertEqual(restored_node.metadata['converged'], False) + + def test_roundtrip_omits_empty_ess_trsh_methods(self): + """CalculationNode with ess_trsh_methods=[] should omit it from dict.""" + node = CalculationNode(node_id='calc_99', label='A', ess_trsh_methods=[]) + d = node.as_dict() + self.assertNotIn('ess_trsh_methods', d) + + def test_ancestors_with_diamond_dependency(self): + """DAG diamond: A -> B -> D, A -> C -> D — ancestors(D) = {A, B, C}.""" + a = self.graph.add_species_node(label='A') + b = self.graph.add_calculation_node(label='A', job_name='opt') + c = self.graph.add_calculation_node(label='A', job_name='freq') + d = self.graph.add_data_node(label='A', data_kind=DataKind.energy) + self.graph.add_edge(a, b, EdgeType.input_of) + self.graph.add_edge(a, c, EdgeType.input_of) + self.graph.add_edge(b, d, EdgeType.output_of) + self.graph.add_edge(c, d, EdgeType.output_of) + anc = self.graph.ancestors(d) + self.assertEqual(set(anc), {a, b, c}) + + def test_descendants_handles_self_loop(self): + """If a self-loop is accidentally created, traversal should not infinite-loop.""" + nid = self.graph.add_species_node(label='A') + self.graph.add_edge(nid, nid, EdgeType.input_of) + desc = self.graph.descendants(nid) + self.assertIn(nid, desc) + + def 
test_query_enum_and_string_equivalence(self): + """Query with NodeType enum and plain string should return identical results.""" + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + r1 = self.graph.query(node_type=NodeType.calculation) + r2 = self.graph.query(node_type='calculation') + self.assertEqual(len(r1), len(r2)) + self.assertEqual(r1[0].node_id, r2[0].node_id) + + def test_counter_with_mixed_prefixes_after_restart(self): + """Counter should track max across ALL prefixes, not per-prefix.""" + self.graph.add_species_node(label='A') # species_1 + self.graph.add_species_node(label='B') # species_2 + self.graph.add_calculation_node(label='A', job_name='opt') # calc_3 + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + # Counter should be >= 3, so next ID suffix is >= 4 + new_id = restored.add_data_node(label='A', data_kind=DataKind.energy) + suffix = int(new_id.split('_')[-1]) + self.assertGreaterEqual(suffix, 4) + + def test_render_all_edge_types(self): + """Verify render_provenance_graph handles every EdgeType without errors.""" + try: + import graphviz as gv_mod + except ImportError: + self.skipTest('graphviz not installed') + from arc.plotter import render_provenance_graph + g = ProvenanceGraph(project='edge_type_test') + n1 = g.add_species_node(label='A') + n2 = g.add_calculation_node(label='A', job_name='opt', status='done') + g.add_data_node(label='A', data_kind=DataKind.energy) + g.add_decision_node(label='A', decision_kind=DecisionKind.conformer_selection) + g.add_calculation_node(label='A', job_name='opt2', status='errored') + for et in list(EdgeType): + g.add_edge(n1, n2, et) + gv = render_provenance_graph(g, run_label='test') + dot = gv.source + self.assertIn('species_1', dot) + self.assertIn('calc_2', dot) + + def test_render_none_labels(self): + """Nodes with label=None should render using node_id as fallback.""" + try: + import graphviz as gv_mod + except ImportError: + self.skipTest('graphviz not 
installed') + from arc.plotter import render_provenance_graph + g = ProvenanceGraph(project='none_label_test') + g.add_species_node(label=None) + g.add_calculation_node(label=None, job_name='opt', status='pending') + gv = render_provenance_graph(g, run_label='test') + dot = gv.source + # Should not crash; node_id is used as fallback for species + self.assertIn('species_1', dot) + + +if __name__ == '__main__': + unittest.main() diff --git a/arc/scheduler.py b/arc/scheduler.py index 80ee664eae..5ddaea92ae 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -57,6 +57,7 @@ ) from arc.species.perceive import perceive_molecule_from_xyz from arc.species.vectors import get_angle, calculate_dihedral_angle +from arc.provenance import (ProvenanceGraph, EdgeType, NodeType, DecisionKind) if TYPE_CHECKING: from arc.job.adapter import JobAdapter @@ -304,6 +305,8 @@ def __init__(self, 'events': list(), } self.provenance_path = os.path.join(self.project_directory, 'output', 'provenance.yml') + self.graph = ProvenanceGraph(project=self.project, run_id=self.provenance['run_id']) + self.graph_path = os.path.join(self.project_directory, 'output', 'provenance_graph.yml') self.species_dict, self.rxn_dict = dict(), dict() for species in self.species_list: @@ -382,6 +385,7 @@ def __init__(self, label=ts_species.label, is_ts=True, ) + self.graph.add_species_node(label=ts_species.label, is_ts=True) else: # The TS species was already loaded from a restart dict or an Arkane YAML file. 
ts_species = None @@ -611,14 +615,22 @@ def _initialize_provenance(self): self.provenance['events'] = raw_events else: logger.warning('Existing provenance.yml has invalid events; starting with an empty event log.') + if os.path.isfile(self.graph_path): + try: + self.graph = ProvenanceGraph.load(self.graph_path) + except Exception: + logger.warning('Could not parse existing provenance_graph.yml; starting a fresh graph.') already_initialized = {e['label'] for e in self.provenance['events'] if e.get('event_type') == 'species_initialized' and isinstance(e.get('label'), str)} + already_in_graph = {n.label for n in self.graph.get_nodes_by_type(NodeType.species)} for species in self.species_list: if species.label not in already_initialized: self.record_provenance_event(event_type='species_initialized', label=species.label, is_ts=species.is_ts, ) + if species.label not in already_in_graph: + self.graph.add_species_node(label=species.label, is_ts=species.is_ts) def record_provenance_event(self, event_type: str, @@ -640,17 +652,23 @@ def record_provenance_event(self, self.save_provenance() def save_provenance(self): - """Persist the provenance event log.""" + """Persist the provenance event log. The graph is saved lazily via save_provenance_graph().""" output_directory = os.path.dirname(self.provenance_path) if not os.path.isdir(output_directory): os.makedirs(output_directory) save_yaml_file(path=self.provenance_path, content=self.provenance) + def save_provenance_graph(self): + """Persist the provenance graph to disk. 
Called at checkpoints and finalization, not per-event.""" + self.graph.save(self.graph_path) + def finalize_provenance(self): """Render final provenance artifacts after the run completes.""" self.provenance['ended_at'] = datetime.datetime.now().isoformat(timespec='seconds') + self.graph.save(self.graph_path) plotter.save_provenance_artifacts(project_directory=self.project_directory, provenance=self.provenance, + graph=self.graph, ) def schedule_jobs(self): @@ -1096,6 +1114,26 @@ def run_job(self, provenance_parent_job=provenance_parent_job, provenance_reason=provenance_reason, ) + # ── Graph: add CalculationNode ── + calc_node_id = self.graph.add_calculation_node( + label=provenance_label, + job_name=job.job_name, + job_type=job.job_type, + job_adapter=job.job_adapter, + level=level_repr, + status='pending', + conformer=conformer, + tsg=tsg, + ess_trsh_methods=job.ess_trsh_methods if job.ess_trsh_methods else None, + ) + species_node_id = self.graph.find_species_node(provenance_label) + if species_node_id is not None: + self.graph.add_edge(species_node_id, calc_node_id, EdgeType.belongs_to) + if provenance_parent_job: + parent_node_id = self.graph.find_calc_node(provenance_label, provenance_parent_job) + if parent_node_id is not None: + edge_type = EdgeType.fine_of if provenance_reason == 'fine_opt' else EdgeType.retried_as + self.graph.add_edge(parent_node_id, calc_node_id, edge_type) job.execute() self.save_restart_dict() @@ -1216,6 +1254,7 @@ def end_job(self, job: 'JobAdapter', self.timer = False job.write_completed_job_to_csv_file() logger.info(f' Ending job {job_name} for {label} (run time: {job.run_time})') + job_status_str = job.job_status[1]['status'] if job.job_status[1]['status'] else job.job_status[0] self.record_provenance_event( event_type='job_finished', label=label, @@ -1223,11 +1262,18 @@ def end_job(self, job: 'JobAdapter', job_key=f'{label}:{job.job_name}', job_name=job.job_name, job_type=job.job_type, - status=job.job_status[1]['status'] if 
job.job_status[1]['status'] else job.job_status[0], + status=job_status_str, keywords=job.job_status[1]['keywords'], error=job.job_status[1]['error'], run_time=str(job.run_time) if job.run_time is not None else None, ) + # ── Graph: update CalculationNode status ── + prov_label = '+'.join(label) if isinstance(label, list) else label + calc_nid = self.graph.find_calc_node(prov_label, job.job_name) + if calc_nid is not None: + self.graph.update_node(calc_nid, + status=job_status_str, + run_time=str(job.run_time) if job.run_time is not None else None) if job.job_status[0] != 'done': return False if job.job_adapter in ['gaussian', 'terachem'] and os.path.isfile(os.path.join(job.local_path, 'check.chk')) \ @@ -1355,7 +1401,16 @@ def run_ts_conformer_jobs(self, label: str): Args: label (str): The TS species label. """ - self.species_dict[label].cluster_tsgs() + cluster_summary = self.species_dict[label].cluster_tsgs() + if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_clustering, + criteria={'n_before': cluster_summary['n_before'], + 'n_after': cluster_summary['n_after'], + 'merged': cluster_summary['merged']}, + outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', + ) plotter.save_conformers_file( project_directory=self.project_directory, label=label, @@ -1377,18 +1432,15 @@ def run_ts_conformer_jobs(self, label: str): if not piped_indices: self.job_dict[label]['conf_opt'] = dict() for i, tsg in enumerate(successful_tsgs): - tsg.conformer_index = i # Store the conformer index to match them later. 
- if i in piped_indices: - continue - if 'conf_opt' not in self.job_dict[label]: - self.job_dict[label]['conf_opt'] = dict() - self.job_dict[label]['conf_opt'] = dict() - for tsg in successful_tsgs: if tsg.index is None: existing_indices = [guess.index for guess in self.species_dict[label].ts_guesses if guess.index is not None] tsg.index = max(existing_indices or [-1]) + 1 - tsg.conformer_index = tsg.index # Set before run_job so restart state is consistent. + tsg.conformer_index = tsg.index + if i in piped_indices: + continue + if 'conf_opt' not in self.job_dict[label]: + self.job_dict[label]['conf_opt'] = dict() self.run_job(label=label, xyz=tsg.initial_xyz, level_of_theory=self.ts_guess_level, @@ -1821,6 +1873,7 @@ def spawn_ts_jobs(self): else: rxn.ts_species.tsg_spawned = True tsg_index = 0 + spawned_methods = [] for method in self.ts_adapters: if method in all_families_ts_adapters or \ (rxn.family is not None @@ -1832,7 +1885,21 @@ def spawn_ts_jobs(self): reactions=[rxn], tsg=tsg_index, ) + spawned_methods.append(method) tsg_index += 1 + # ── Graph: record TS method spawning decision ── + if spawned_methods: + dec_nid = self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_method_spawning, + criteria={'family': rxn.family, + 'all_adapters': list(self.ts_adapters), + 'spawned': spawned_methods}, + outcome=f'Spawned {len(spawned_methods)} TS guess methods', + ) + spc_nid = self.graph.find_species_node(rxn.ts_label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, dec_nid, EdgeType.triggered_by) if all('user guess' in tsg.method for tsg in rxn.ts_species.ts_guesses): rxn.ts_species.tsg_spawned = True self.run_conformer_jobs(labels=[rxn.ts_label]) @@ -2384,6 +2451,17 @@ def determine_most_stable_conformer(self, label, sp_flag=False): self.output[label]['job_types']['conf_opt'] = True if sp_flag: self.output[label]['job_types']['conf_sp'] = True + # ── Graph: record conformer selection decision ── + selected_idx = 
xyzs_in_original_order.index(conformer_xyz) + non_none_energies = [(i, e) for i, e in enumerate( + self.species_dict[label].conformer_energies) if e is not None] + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.conformer_selection, + criteria={'n_conformers': len(non_none_energies), + 'isomorphic': self.species_dict[label].conf_is_isomorphic}, + outcome=f'Selected conformer #{selected_idx}', + ) def determine_most_likely_ts_conformer(self, label: str): """ @@ -2393,7 +2471,16 @@ def determine_most_likely_ts_conformer(self, label: str): Args: label (str): The TS species label. """ - self.species_dict[label].cluster_tsgs() + cluster_summary = self.species_dict[label].cluster_tsgs() + if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_clustering, + criteria={'n_before': cluster_summary['n_before'], + 'n_after': cluster_summary['n_after'], + 'merged': cluster_summary['merged']}, + outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', + ) if not self.species_dict[label].is_ts: raise SchedulerError('determine_most_likely_ts_conformer() method only processes transition state guesses.') if not self.species_dict[label].successful_methods: @@ -2443,6 +2530,11 @@ def determine_most_likely_ts_conformer(self, label: str): label=label, is_ts=True, ) + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_selection_failed, + outcome='No viable TS guess found', + ) return None else: rxn_txt = '' if self.species_dict[label].rxn_label is None \ @@ -2469,6 +2561,12 @@ def determine_most_likely_ts_conformer(self, label: str): method=tsg.method, energy=tsg.energy, ) + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_selection, + criteria={'selected_index': selected_i, 'energy': tsg.energy}, + outcome=f'Selected TSGuess 
#{selected_i} via {tsg.method}', + ) if tsg.success and tsg.energy is not None: # guess method and ts_level opt were both successful tsg.energy -= e_min im_freqs = f', imaginary frequencies {tsg.imaginary_freqs}' if tsg.imaginary_freqs is not None else '' @@ -2758,6 +2856,12 @@ def check_freq_job(self, logger.info(f'TS {label} did not pass the normal mode displacement check. ' f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') + # ── Graph: record NMD validation failure ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_nmd, + outcome='Failed: normal mode displacement check', + ) self.switch_ts(label) switch_ts = True if wrong_freq_message in self.output[label]['warnings']: @@ -2825,10 +2929,25 @@ def check_negative_freq(self, logger.info(f'TS {label} did not pass the negative frequency check. ' f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') + # ── Graph: record TS freq validation failure ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_freq, + criteria={'neg_freqs': [float(f) for f in neg_freqs], + 'expected': 1}, + outcome=f'Failed: {len(neg_freqs)} imaginary freqs, switching TS', + ) self.switch_ts(label=label) return False else: logger.info(f'TS {label} has exactly one imaginary frequency: {neg_freqs[0]}') + # ── Graph: record TS freq validation pass ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_freq, + criteria={'neg_freqs': [float(neg_freqs[0])]}, + outcome='Passed: exactly 1 imaginary frequency', + ) self.output[label]['info'] += f'Imaginary frequency: {neg_freqs[0] if len(neg_freqs) == 1 else neg_freqs}; ' self.output[label]['job_types']['freq'] = True self.output[label]['paths']['freq'] = job.local_path_to_output_file @@ -2898,7 +3017,18 @@ def switch_ts(self, label: str): label (str): The TS species label. 
""" logger.info(f'Switching a TS guess for {label}...') + old_chosen = self.species_dict[label].chosen_ts self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess. + new_chosen = self.species_dict[label].chosen_ts + # ── Graph: record TS switch decision ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_switch, + criteria={'old_chosen': old_chosen, 'new_chosen': new_chosen, + 'exhausted': self.species_dict[label].ts_guesses_exhausted}, + outcome=f'Switched from TSG #{old_chosen} to #{new_chosen}' + if new_chosen is not None else 'All TS guesses exhausted', + ) self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path): @@ -3070,10 +3200,18 @@ def check_irc_species(self, label: str): if len(self.output[ts_label]['paths']['irc']) == 2: irc_species_labels = self.species_dict[ts_label].irc_label.split() if all(self.output[irc_label]['paths']['geo'] for irc_label in irc_species_labels): - check_irc_species_and_rxn(xyz_1=self.output[irc_species_labels[0]]['paths']['geo'], - xyz_2=self.output[irc_species_labels[1]]['paths']['geo'], - rxn=self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None), - ) + check_irc_species_and_rxn( + xyz_1=self.output[irc_species_labels[0]]['paths']['geo'], + xyz_2=self.output[irc_species_labels[1]]['paths']['geo'], + rxn=self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None), + ) + # ── Graph: record IRC validation decision ── + self.graph.add_decision_node( + label=ts_label, + decision_kind=DecisionKind.ts_validation_irc, + criteria={'irc_species': irc_species_labels}, + outcome='IRC validation completed', + ) def check_scan_job(self, label: str, @@ -3331,18 +3469,6 @@ def check_all_done(self, label: str): logger.debug(f'Species {label} did not converge.') all_converged = False break - for job_type, 
spawn_job_type in self.job_types.items(): - if spawn_job_type and not self.output[label]['job_types'][job_type] \ - and not ((self.species_dict[label].is_ts and job_type in ['scan', 'conf_opt']) - or (self.species_dict[label].number_of_atoms == 1 - and job_type in ['conf_opt', 'opt', 'fine', 'freq', 'rotors', 'bde']) - or job_type == 'bde' and self.species_dict[label].bdes is None - or job_type == 'conf_opt' - or job_type == 'irc' - or job_type == 'tsg'): - logger.debug(f'Species {label} did not converge.') - all_converged = False - break if all_converged and self._missing_required_paths(label): logger.debug(f'Species {label} did not converge due to missing output paths.') all_converged = False @@ -3764,7 +3890,14 @@ def troubleshoot_ess(self, f'log file:\n"{job.job_status[1]["line"]}".' logger.warning(warning_message) if self.species_dict[label].is_ts and conformer is not None: - xyz = self.species_dict[label].ts_guesses[conformer].get_xyz() + tsg = next((t for t in self.species_dict[label].ts_guesses + if t.index == conformer or t.conformer_index == conformer), None) + if tsg is not None: + xyz = tsg.get_xyz() + else: + logger.warning(f'Could not find TS guess with index {conformer} for {label}; ' + f'falling back to species xyz.') + xyz = self.species_dict[label].final_xyz or self.species_dict[label].initial_xyz elif conformer is not None: xyz = self.species_dict[label].conformers[conformer] else: @@ -3816,6 +3949,17 @@ def troubleshoot_ess(self, keywords=job.job_status[1]['keywords'], error=job.job_status[1]['error'], ) + # ── Graph: record troubleshooting decision ── + trsh_dec_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': job.job_status[1]['keywords'], + 'error': job.job_status[1]['error']}, + outcome=f'Retrying with {", ".join(ess_trsh_methods[-1:])}' if ess_trsh_methods else 'Retrying', + ) + failed_calc_nid = self.graph.find_calc_node(label, job.job_name) + if 
failed_calc_nid is not None: + self.graph.add_edge(failed_calc_nid, trsh_dec_nid, EdgeType.troubleshot_by) self.run_job(label=label, xyz=xyz, level_of_theory=level_of_theory, @@ -3839,6 +3983,10 @@ def troubleshoot_ess(self, and conformer is None: # Only switch TS guess when a full optimization fails, not when a single # conformer search job fails. Other conformers may still be running. + logger.info(f'TS {label} did not converge. ' + f'Status is:\n{self.species_dict[label].ts_checks}\n' + f'Searching for a better TS conformer...') + self.switch_ts(label=label) elif self.species_dict[label].is_ts and not self.species_dict[label].ts_guesses_exhausted: # During TS conf_opt screening, avoid switching mid-batch since switch_ts() deletes all # running jobs for this TS label and can discard other viable TS guesses still running. @@ -4040,8 +4188,9 @@ def save_restart_dict(self): for job_name in self.running_jobs[spc.label] if 'conf_sp' in job_name] \ + [self.job_dict[spc.label]['tsg'][get_i_from_job_name(job_name)].as_dict() for job_name in self.running_jobs[spc.label] if 'tsg' in job_name] - save_yaml_file(path=self.restart_path, content=self.restart_dict) - + save_yaml_file(path=self.restart_path, content=self.restart_dict) + self.save_provenance_graph() + def make_reaction_labels_info_file(self): """ A helper function for creating the `reactions labels.info` file. 
diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index c0667f7f59..44ac2442b0 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -912,6 +912,59 @@ def test_provenance_multi_species_label(self): self.assertEqual(event['label'], 'H2+O2') self.assertIsInstance(event['label'], str) + def test_provenance_graph_species_initialized(self): + """Test that the provenance graph contains species nodes after initialization.""" + spc = ARCSpecies(label='water', smiles='O') + project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov_graph') + os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True) + sched = Scheduler(project='test_prov_graph', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=project_directory, + testing=True, job_types=initialize_job_types()) + from arc.provenance import NodeType + species_nodes = sched.graph.get_nodes_by_type(NodeType.species) + self.assertEqual(len(species_nodes), 1) + self.assertEqual(species_nodes[0].label, 'water') + # Graph is saved lazily (at checkpoints/finalization, not per-event). + # Verify it can be saved on demand. 
+ sched.save_provenance_graph() + self.assertTrue(os.path.isfile(sched.graph_path)) + shutil.rmtree(project_directory, ignore_errors=True) + + def test_provenance_graph_restart_preserves_nodes(self): + """Test that the provenance graph is restored correctly on restart.""" + spc = ARCSpecies(label='methane', smiles='C') + project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov_graph2') + os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True) + # Create initial scheduler to write provenance files + sched1 = Scheduler(project='test_restart', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=project_directory, + testing=True, job_types=initialize_job_types()) + n_nodes_before = len(sched1.graph) + self.assertGreater(n_nodes_before, 0) + sched1.save_provenance_graph() # Persist graph so the restart can load it. 
+ # Create second scheduler on same directory (simulates restart) + sched2 = Scheduler(project='test_restart', ess_settings=self.ess_settings, + species_list=[spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=project_directory, + testing=True, job_types=initialize_job_types()) + from arc.provenance import NodeType + species_nodes = sched2.graph.get_nodes_by_type(NodeType.species) + # Should still have exactly 1 species node (no duplicate) + self.assertEqual(len(species_nodes), 1) + self.assertEqual(species_nodes[0].label, 'methane') + shutil.rmtree(project_directory, ignore_errors=True) + @classmethod def tearDownClass(cls): """ @@ -919,7 +972,9 @@ def tearDownClass(cls): Delete all project directories created during these unit tests """ projects = ['arc_project_for_testing_delete_after_usage3', 'arc_project_for_testing_delete_after_usage6', - 'arc_project_for_testing_delete_after_usage_prov'] + 'arc_project_for_testing_delete_after_usage_prov', + 'arc_project_for_testing_delete_after_usage_prov_graph', + 'arc_project_for_testing_delete_after_usage_prov_graph2'] for project in projects: project_directory = os.path.join(ARC_PATH, 'Projects', project) shutil.rmtree(project_directory, ignore_errors=True) diff --git a/arc/species/species.py b/arc/species/species.py index f9cacbf7b0..485651eb4c 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -1553,6 +1553,11 @@ def make_ts_report(self): def cluster_tsgs(self): """ Cluster TSGuesses. + + Returns: + Optional[dict]: ``None`` if this species is not a TS or has no TS guesses. + Otherwise a summary dict with keys ``n_before``, ``n_after``, and + ``merged`` (list of lists of merged indices). 
""" if not self.is_ts or not len(self.ts_guesses): return None @@ -1574,6 +1579,9 @@ def cluster_tsgs(self): if len(cluster_tsgs) < n_before: logger.info(f'Clustered {n_before} TS guesses for {self.label} ' f'into {len(cluster_tsgs)} unique conformers.') + return {'n_before': n_before, + 'n_after': len(cluster_tsgs), + 'merged': [tsg.cluster for tsg in cluster_tsgs if len(tsg.cluster) > 1]} def process_completed_tsg_queue_jobs(self, path: str): """ From 5a3febb39ae2b7a5995cefabe9dd4b503269f3be Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sun, 12 Apr 2026 18:13:43 +0300 Subject: [PATCH 11/14] Refining --- .gitignore | 7 ++++ arc/job/pipe/pipe_coordinator.py | 21 ++++++++++ arc/plotter.py | 69 ++++++++++++++++++++++++++++---- arc/scheduler.py | 39 +++++++++++++++--- 4 files changed, 124 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 9132dc1f86..90f8505630 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,10 @@ build/* # AI Agent files AGENTS.md +CLAUDE.md + +# Provenance related +provenance.yml +provenance_graph.yml +provenance.svg +*.dot diff --git a/arc/job/pipe/pipe_coordinator.py b/arc/job/pipe/pipe_coordinator.py index b5a0d874e8..3c9631758c 100644 --- a/arc/job/pipe/pipe_coordinator.py +++ b/arc/job/pipe/pipe_coordinator.py @@ -274,13 +274,16 @@ def ingest_pipe_results(self, pipe: PipeRun) -> None: if state.status == TaskState.COMPLETED.value: ingest_completed_task(pipe.run_id, pipe.pipe_root, spec, state, self.sched.species_dict, self.sched.output) + self._update_graph_for_pipe_task(spec, status='done') elif state.status == TaskState.FAILED_ESS.value: self._eject_to_scheduler(pipe, spec, state) + self._update_graph_for_pipe_task(spec, status='errored') ejected_count += 1 elif state.status == TaskState.FAILED_TERMINAL.value: logger.error(f'Pipe run {pipe.run_id}, task {spec.task_id}: ' f'failed terminally (failure_class={state.failure_class}). 
' f'Manual troubleshooting required.') + self._update_graph_for_pipe_task(spec, status='errored') elif state.status == TaskState.CANCELLED.value: logger.warning(f'Pipe run {pipe.run_id}, task {spec.task_id}: ' f'was cancelled.') @@ -290,6 +293,24 @@ def ingest_pipe_results(self, pipe: PipeRun) -> None: else: self._post_ingest_pipe_run(pipe) + def _update_graph_for_pipe_task(self, spec: TaskSpec, status: str) -> None: + """Update the provenance graph calc node for a completed/failed pipe task.""" + graph = getattr(self.sched, 'graph', None) + if graph is None: + return + label = spec.owner_key + meta = spec.ingestion_metadata or {} + job_type = TASK_FAMILY_TO_JOB_TYPE.get(spec.task_family, spec.task_family) + # Build the job_name the scheduler would have used for this task. + conf_idx = meta.get('conformer_index') + if conf_idx is not None: + job_name = f'{job_type}_{conf_idx}' + else: + job_name = spec.task_id # fallback to pipe task_id + calc_nid = graph.find_calc_node(label, job_name) + if calc_nid is not None: + graph.update_node(calc_nid, status=status) + def _post_ingest_pipe_run(self, pipe: PipeRun) -> None: """ Trigger family-specific post-processing after all tasks in a pipe run diff --git a/arc/plotter.py b/arc/plotter.py index 1d8df116e9..c25d84549c 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -134,7 +134,40 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz 'spawned_by': {'color': 'blue', 'style': 'solid'}, } + # ── Identify conf_opt batches to collapse ────────────────────────────── + # Group conf_opt calc nodes by species label for batch summarization. 
+ _COLLAPSE_THRESHOLD = 5 # only collapse if more than this many conf_opt jobs + conf_opt_groups = {} # label -> list of node_ids + conf_opt_collapsed = set() # node_ids that will be replaced by a summary for node in prov_graph.nodes.values(): + if (node.node_type == 'calculation' + and getattr(node, 'job_type', '') == 'conf_opt'): + conf_opt_groups.setdefault(node.label, []).append(node.node_id) + batch_summary_ids = {} # label -> summary_node_graphviz_id + for label, nids in conf_opt_groups.items(): + if len(nids) > _COLLAPSE_THRESHOLD: + conf_opt_collapsed.update(nids) + statuses = [getattr(prov_graph.get_node(n), 'status', 'pending') or 'pending' for n in nids] + done = statuses.count('done') + errored = statuses.count('errored') + pending = len(statuses) - done - errored + parts = [] + if done: + parts.append(f'{done} done') + if errored: + parts.append(f'{errored} errored') + if pending: + parts.append(f'{pending} pending') + summary_id = _sanitize_graphviz_id(f'batch_conf_opt_{label}') + batch_summary_ids[label] = summary_id + gv.node(summary_id, + _wrap_graph_label(f'conf_opt batch\n{len(nids)} jobs\n{", ".join(parts)}', width=28), + shape='box3d', fillcolor='lightyellow', style='filled') + + # ── Render individual nodes ────────────────────────────────────────── + for node in prov_graph.nodes.values(): + if node.node_id in conf_opt_collapsed: + continue # replaced by batch summary nid = _sanitize_graphviz_id(node.node_id) ntype = node.node_type @@ -146,12 +179,11 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor='aliceblue') elif ntype == 'calculation': - parts = [getattr(node, 'job_type', '') or '', getattr(node, 'job_name', '') or ''] + job_type = getattr(node, 'job_type', '') or '' + job_name = getattr(node, 'job_name', '') or '' + lbl = f'{job_type}\n{job_name}' if getattr(node, 'job_adapter', None): - parts.append(node.job_adapter) - if getattr(node, 'level', 
None): - parts.append(node.level) - lbl = '\n'.join(p for p in parts if p) + lbl += f'\n{node.job_adapter}' status = getattr(node, 'status', 'pending') or 'pending' fillcolor = _calc_colors.get(status, 'white') gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor=fillcolor) @@ -162,6 +194,9 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz lbl = dk if val is not None and not isinstance(val, (list, dict)): lbl += f'\n{val}' + meta = getattr(node, 'metadata', None) or {} + if 'n_imaginary' in meta: + lbl += f'\n({meta["n_imaginary"]} imag)' gv.node(nid, _wrap_graph_label(lbl), shape='note', fillcolor='cornsilk') elif ntype == 'decision': @@ -176,11 +211,31 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz else: gv.node(nid, _wrap_graph_label(node.node_id)) + # ── Render edges ───────────────────────────────────────────────────── + # Track which batch summaries have been connected to avoid duplicate edges. + batch_edges_added = set() for edge in prov_graph.edges: - src = _sanitize_graphviz_id(edge.source_id) - tgt = _sanitize_graphviz_id(edge.target_id) + src_collapsed = edge.source_id in conf_opt_collapsed + tgt_collapsed = edge.target_id in conf_opt_collapsed + # Redirect edges involving collapsed conf_opt nodes to the batch summary. + if src_collapsed: + src_label = prov_graph.get_node(edge.source_id).label if prov_graph.get_node(edge.source_id) else None + src = batch_summary_ids.get(src_label, _sanitize_graphviz_id(edge.source_id)) + else: + src = _sanitize_graphviz_id(edge.source_id) + if tgt_collapsed: + tgt_label = prov_graph.get_node(edge.target_id).label if prov_graph.get_node(edge.target_id) else None + tgt = batch_summary_ids.get(tgt_label, _sanitize_graphviz_id(edge.target_id)) + else: + tgt = _sanitize_graphviz_id(edge.target_id) + # Deduplicate edges to/from batch summaries. 
+ edge_key = (src, tgt, edge.edge_type) + if (src_collapsed or tgt_collapsed) and edge_key in batch_edges_added: + continue + batch_edges_added.add(edge_key) etype = edge.edge_type style_attrs = _edge_styles.get(etype, {}) + # Only show labels on semantically interesting edges (not belongs_to, input_of, output_of). label = etype.replace('_', ' ') if etype not in ('belongs_to', 'input_of', 'output_of') else '' gv.edge(src, tgt, label=label, **style_attrs) diff --git a/arc/scheduler.py b/arc/scheduler.py index 5ddaea92ae..18ee3cef8e 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -57,7 +57,7 @@ ) from arc.species.perceive import perceive_molecule_from_xyz from arc.species.vectors import get_angle, calculate_dihedral_angle -from arc.provenance import (ProvenanceGraph, EdgeType, NodeType, DecisionKind) +from arc.provenance import (ProvenanceGraph, EdgeType, NodeType, DataKind, DecisionKind) if TYPE_CHECKING: from arc.job.adapter import JobAdapter @@ -385,7 +385,8 @@ def __init__(self, label=ts_species.label, is_ts=True, ) - self.graph.add_species_node(label=ts_species.label, is_ts=True) + if self.graph.find_species_node(ts_species.label) is None: + self.graph.add_species_node(label=ts_species.label, is_ts=True) else: # The TS species was already loaded from a restart dict or an Arkane YAML file. ts_species = None @@ -602,8 +603,13 @@ def _flush_pending_pipe_conf_sp(self) -> None: self.run_sp_job(label=label, level=self.conformer_sp_level, conformer=i) def _initialize_provenance(self): - """Load previous provenance when restarting and record the current run start.""" - if os.path.isfile(self.provenance_path): + """Load previous provenance when restarting and record the current run start. + + On a fresh run (no restart_dict), the event log and graph start empty. + On a restart, the previous event log and graph are loaded and deduplicated. 
+ """ + is_restart = self.restart_dict is not None + if is_restart and os.path.isfile(self.provenance_path): try: provenance = read_yaml_file(self.provenance_path) except Exception: @@ -615,7 +621,7 @@ def _initialize_provenance(self): self.provenance['events'] = raw_events else: logger.warning('Existing provenance.yml has invalid events; starting with an empty event log.') - if os.path.isfile(self.graph_path): + if is_restart and os.path.isfile(self.graph_path): try: self.graph = ProvenanceGraph.load(self.graph_path) except Exception: @@ -2282,6 +2288,13 @@ def parse_conformer(self, logger.debug(f'Energy for conformer {i} of {label} is {energy:.2f}') else: logger.debug(f'Energy for conformer {i} of {label} is None') + # ── Graph: emit energy DataNode from conformer job ── + if energy is not None: + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.energy, value=round(energy, 2)) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) else: logger.warning(f'Conformer {i} for {label} did not converge.') if job.job_status[1]['status'] == 'errored' and job.times_rerun == 0: @@ -2830,6 +2843,14 @@ def check_freq_job(self, freq_ok = self.check_negative_freq(label=label, job=job, vibfreqs=vibfreqs) if freq_ok and vibfreqs is not None: self.species_dict[label].freqs = [float(f) for f in vibfreqs] + # ── Graph: emit frequencies DataNode ── + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.frequencies, + value=len(vibfreqs), + metadata={'n_imaginary': sum(1 for f in vibfreqs if f < 0)}) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) if freq_ok: # Copy the frequency file to the species / TS output folder. 
folder_name = 'rxns' if self.species_dict[label].is_ts else 'Species' @@ -3062,6 +3083,14 @@ def check_sp_job(self, sp_path=os.path.join(job.local_path_to_output_file), level=job.level, ) + # ── Graph: emit SP energy DataNode ── + if self.species_dict[label].e_elect is not None: + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.energy, + value=round(self.species_dict[label].e_elect, 2)) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) # Update restart dictionary and save the yaml restart file: self.save_restart_dict() if self.species_dict[label].number_of_atoms == 1: From 272fc55bc0cc7f3c9439031ccec0e6d889aafe00 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Sun, 12 Apr 2026 19:11:08 +0300 Subject: [PATCH 12/14] Emitting SP node when LoT for geo == sp Added also TS troubleshoots --- arc/scheduler.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 18ee3cef8e..df52dbb53b 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -1596,6 +1596,16 @@ def run_sp_job(self, ) else: raise RuntimeError(f'Unable to set the path for the sp job for species {label}') + # ── Graph: emit energy DataNode from opt log (sp_level == opt_level) ── + if self.species_dict[label].e_elect is not None: + opt_calc_nid = self.graph.find_calc_node(label, recent_opt_job_name) \ + if recent_opt_job is not None else None + if opt_calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.energy, + value=round(self.species_dict[label].e_elect, 2), + metadata={'source': 'opt_log', 'note': 'SP energy parsed from opt output'}) + self.graph.add_edge(opt_calc_nid, data_nid, EdgeType.output_of) return if 'sp' not in self.job_dict[label].keys(): @@ -3672,6 +3682,16 @@ def troubleshoot_negative_freq(self, logger.info(f'Deleting all currently 
running jobs for species {label} before troubleshooting for ' f'negative frequency with perturbed conformers...') logger.info(f'conformers:') + # ── Graph: record negative freq troubleshooting ── + trsh_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'negative_freq', 'n_conformers': len(confs)}, + outcome=f'Generated {len(confs)} perturbed conformers', + ) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, trsh_nid, EdgeType.troubleshot_by) self.delete_all_species_jobs(label) self.species_dict[label].conformers = confs self.species_dict[label].conformer_energies = [None] * len(confs) @@ -3822,6 +3842,18 @@ def troubleshoot_scan_job(self, trsh={'scan_res': scan_res} if scan_res is not None else None, rotor_index=job.rotor_index, ) + # ── Graph: record scan troubleshooting decision ── + if trsh_success: + label = job.species_label + trsh_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'scan', 'actions': actual_actions}, + outcome=f'Scan troubleshooting: {", ".join(str(k) for k in actual_actions)}', + ) + scan_calc_nid = self.graph.find_calc_node(label, job.job_name) + if scan_calc_nid is not None: + self.graph.add_edge(scan_calc_nid, trsh_nid, EdgeType.troubleshot_by) return trsh_success, actual_actions def troubleshoot_opt_jobs(self, label): @@ -4067,7 +4099,13 @@ def troubleshoot_conformer_isomorphism(self, label: str): 'graph representation!; ' else: logger.info(f'Troubleshooting conformer job in {job.job_adapter} using {level_of_theory} for species {label}') - + # ── Graph: record conformer isomorphism troubleshooting ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'conformer_isomorphism', 'new_level': str(level_of_theory)}, + outcome=f'Rerunning {num_of_conformers} 
conformers at {level_of_theory}', + ) # rerun conformer job at higher level for all conformers for conformer in range(0, num_of_conformers): if conformer >= len(self.species_dict[label].conformers_before_opt): From 2b02329560a7235099cc7c436dec78d0801a296e Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Mon, 13 Apr 2026 17:30:53 +0300 Subject: [PATCH 13/14] Updates --- arc/main.py | 65 ++++++++++++++++ arc/plotter.py | 9 ++- arc/provenance/nodes.py | 3 + arc/provenance/provenance_test.py | 4 +- arc/scheduler.py | 125 ++++++++++++++++++++++++++---- 5 files changed, 186 insertions(+), 20 deletions(-) diff --git a/arc/main.py b/arc/main.py index f3e2fc9329..f7be659676 100644 --- a/arc/main.py +++ b/arc/main.py @@ -36,6 +36,7 @@ from arc.job.ssh import SSHClient from arc.output import write_output_yml from arc.processor import process_arc_project, resolve_neb_level +from arc.provenance import DecisionKind, EdgeType from arc.reaction import ARCReaction from arc.scheduler import Scheduler from arc.species.converter import str_to_xyz @@ -671,6 +672,70 @@ def execute(self) -> dict: log_footer(execution_time=self.execution_time) return status_dict + def _add_arkane_provenance_nodes(self): + """Add Arkane computation and result nodes to the provenance graph. + + For each converged species with thermo results, creates: + convergence_confirmed → calc(statmech_thermo) → data(thermo) + + For each converged reaction with kinetics results, creates: + convergence_confirmed → calc(statmech_kinetics) → data(kinetics) + """ + graph = self.scheduler.graph + for spc in self.scheduler.species_dict.values(): + if spc.is_ts or getattr(spc.thermo, 'H298', None) is None: + continue + spc_nid = graph.find_species_node(spc.label) + if spc_nid is None: + continue + # Insert a CalculationNode for the Arkane thermo computation. 
+ calc_nid = graph.add_calculation_node( + label=spc.label, + job_name='arkane_thermo', + job_type='statmech_thermo', + job_adapter=self.thermo_adapter, + status='done', + ) + graph.add_edge(spc_nid, calc_nid, EdgeType.belongs_to) + # Link from convergence gate if it exists. + conv_nodes = graph.query(decision_kind=DecisionKind.convergence_confirmed, label=spc.label) + for conv_node in conv_nodes: + graph.add_edge(conv_node.node_id, calc_nid, EdgeType.triggered_by) + thermo_nid = graph.add_data_node( + label=spc.label, + data_kind='thermo', + value=f'H298={spc.thermo.H298:.1f} kJ/mol, S298={spc.thermo.S298:.1f} J/mol/K', + ) + graph.add_edge(calc_nid, thermo_nid, EdgeType.output_of) + for rxn in self.scheduler.rxn_list: + if rxn.kinetics is None: + continue + ts_nid = graph.find_species_node(rxn.ts_label) + if ts_nid is None: + continue + # Insert a CalculationNode for the Arkane kinetics computation. + calc_nid = graph.add_calculation_node( + label=rxn.ts_label, + job_name='arkane_kinetics', + job_type='statmech_kinetics', + job_adapter=self.kinetics_adapter, + status='done', + ) + graph.add_edge(ts_nid, calc_nid, EdgeType.belongs_to) + # Link from TS convergence gate if it exists. + conv_nodes = graph.query(decision_kind=DecisionKind.convergence_confirmed, label=rxn.ts_label) + for conv_node in conv_nodes: + graph.add_edge(conv_node.node_id, calc_nid, EdgeType.triggered_by) + ea = rxn.kinetics.get('Ea') + ea_str = f', Ea={ea[0]:.1f} {ea[1]}' if ea else '' + kinetics_nid = graph.add_data_node( + label=rxn.ts_label, + data_kind='kinetics', + value=f'{rxn.label}{ea_str}', + ) + graph.add_edge(calc_nid, kinetics_nid, EdgeType.output_of) + graph.save(self.scheduler.graph_path) + def save_project_info_file(self): """ Save a project info file. 
diff --git a/arc/plotter.py b/arc/plotter.py index c25d84549c..ed790380f7 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -120,7 +120,10 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz 'ts_validation_freq': 'lightyellow', 'ts_validation_nmd': 'lightyellow', 'ts_validation_irc': 'lightyellow', + 'ts_validation_e0': 'lightyellow', + 'ts_validation_e_elect': 'lightyellow', 'ts_switch': 'mistyrose', + 'convergence_confirmed': 'palegreen', } # Edge styling lookup @@ -235,8 +238,10 @@ def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz batch_edges_added.add(edge_key) etype = edge.edge_type style_attrs = _edge_styles.get(etype, {}) - # Only show labels on semantically interesting edges (not belongs_to, input_of, output_of). - label = etype.replace('_', ' ') if etype not in ('belongs_to', 'input_of', 'output_of') else '' + # Only show labels on semantically interesting edges; suppress purely structural ones. + _suppress = ('belongs_to', 'input_of', 'output_of', 'retried_as') + _rename = {'fine_of': 'fine grid', 'troubleshot_by': 'troubleshoot'} + label = _rename.get(etype, etype.replace('_', ' ')) if etype not in _suppress else '' gv.edge(src, tgt, label=label, **style_attrs) return gv diff --git a/arc/provenance/nodes.py b/arc/provenance/nodes.py index bebe43a8cc..d65cdb374b 100644 --- a/arc/provenance/nodes.py +++ b/arc/provenance/nodes.py @@ -55,9 +55,12 @@ class DecisionKind(str, Enum): ts_validation_freq = 'ts_validation_freq' ts_validation_nmd = 'ts_validation_nmd' ts_validation_irc = 'ts_validation_irc' + ts_validation_e0 = 'ts_validation_e0' + ts_validation_e_elect = 'ts_validation_e_elect' ts_switch = 'ts_switch' job_troubleshooting = 'job_troubleshooting' ts_method_spawning = 'ts_method_spawning' + convergence_confirmed = 'convergence_confirmed' class EdgeType(str, Enum): diff --git a/arc/provenance/provenance_test.py b/arc/provenance/provenance_test.py index 0db5aecef3..060e999143 100644 --- 
a/arc/provenance/provenance_test.py +++ b/arc/provenance/provenance_test.py @@ -33,7 +33,9 @@ def test_data_kind_values(self): def test_decision_kind_values(self): expected = {'conformer_selection', 'ts_guess_clustering', 'ts_guess_selection', 'ts_guess_selection_failed', 'ts_validation_freq', 'ts_validation_nmd', - 'ts_validation_irc', 'ts_switch', 'job_troubleshooting', 'ts_method_spawning'} + 'ts_validation_irc', 'ts_validation_e0', 'ts_validation_e_elect', + 'ts_switch', 'job_troubleshooting', 'ts_method_spawning', + 'convergence_confirmed'} actual = {dk.value for dk in DecisionKind} self.assertEqual(expected, actual) diff --git a/arc/scheduler.py b/arc/scheduler.py index df52dbb53b..14d29eecaf 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -1408,8 +1408,9 @@ def run_ts_conformer_jobs(self, label: str): label (str): The TS species label. """ cluster_summary = self.species_dict[label].cluster_tsgs() + cluster_nid = None if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: - self.graph.add_decision_node( + cluster_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_guess_clustering, criteria={'n_before': cluster_summary['n_before'], @@ -1417,6 +1418,10 @@ def run_ts_conformer_jobs(self, label: str): 'merged': cluster_summary['merged']}, outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', ) + # Connect species → clustering decision. 
+ spc_nid = self.graph.find_species_node(label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, cluster_nid, EdgeType.triggered_by) plotter.save_conformers_file( project_directory=self.project_directory, label=label, @@ -1903,7 +1908,7 @@ def spawn_ts_jobs(self): ) spawned_methods.append(method) tsg_index += 1 - # ── Graph: record TS method spawning decision ── + # ── Graph: record TS method spawning decision and connect to tsg calc nodes ── if spawned_methods: dec_nid = self.graph.add_decision_node( label=rxn.ts_label, @@ -1916,6 +1921,11 @@ def spawn_ts_jobs(self): spc_nid = self.graph.find_species_node(rxn.ts_label) if spc_nid is not None: self.graph.add_edge(spc_nid, dec_nid, EdgeType.triggered_by) + # Connect spawning decision → each tsg calc node. + for i in range(tsg_index): + tsg_calc_nid = self.graph.find_calc_node(rxn.ts_label, f'tsg{i}') + if tsg_calc_nid is not None: + self.graph.add_edge(dec_nid, tsg_calc_nid, EdgeType.spawned_by) if all('user guess' in tsg.method for tsg in rxn.ts_species.ts_guesses): rxn.ts_species.tsg_spawned = True self.run_conformer_jobs(labels=[rxn.ts_label]) @@ -2496,7 +2506,7 @@ def determine_most_likely_ts_conformer(self, label: str): """ cluster_summary = self.species_dict[label].cluster_tsgs() if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: - self.graph.add_decision_node( + cluster_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_guess_clustering, criteria={'n_before': cluster_summary['n_before'], @@ -2504,6 +2514,9 @@ def determine_most_likely_ts_conformer(self, label: str): 'merged': cluster_summary['merged']}, outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', ) + spc_nid = self.graph.find_species_node(label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, cluster_nid, EdgeType.triggered_by) if not self.species_dict[label].is_ts: raise 
SchedulerError('determine_most_likely_ts_conformer() method only processes transition state guesses.') if not self.species_dict[label].successful_methods: @@ -2553,11 +2566,14 @@ def determine_most_likely_ts_conformer(self, label: str): label=label, is_ts=True, ) - self.graph.add_decision_node( + fail_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_guess_selection_failed, outcome='No viable TS guess found', ) + spc_nid = self.graph.find_species_node(label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, fail_nid, EdgeType.triggered_by) return None else: rxn_txt = '' if self.species_dict[label].rxn_label is None \ @@ -2584,12 +2600,20 @@ def determine_most_likely_ts_conformer(self, label: str): method=tsg.method, energy=tsg.energy, ) - self.graph.add_decision_node( + sel_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_guess_selection, criteria={'selected_index': selected_i, 'energy': tsg.energy}, outcome=f'Selected TSGuess #{selected_i} via {tsg.method}', ) + # Connect: conf_opt → selected_by → selection decision. + conf_opt_nid = self.graph.find_calc_node(label, f'conf_opt{selected_i}') + if conf_opt_nid is not None: + self.graph.add_edge(conf_opt_nid, sel_nid, EdgeType.selected_by) + # Connect: selection → species (so subsequent opt flows from it). 
+ spc_nid = self.graph.find_species_node(label) + if spc_nid is not None: + self.graph.add_edge(sel_nid, spc_nid, EdgeType.output_of) if tsg.success and tsg.energy is not None: # guess method and ts_level opt were both successful tsg.energy -= e_min im_freqs = f', imaginary frequencies {tsg.imaginary_freqs}' if tsg.imaginary_freqs is not None else '' @@ -2888,13 +2912,26 @@ def check_freq_job(self, f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') # ── Graph: record NMD validation failure ── - self.graph.add_decision_node( + nmd_fail_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_validation_nmd, outcome='Failed: normal mode displacement check', ) - self.switch_ts(label) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, nmd_fail_nid, EdgeType.output_of) + self.switch_ts(label, triggered_by_nid=nmd_fail_nid) switch_ts = True + elif self.species_dict[label].ts_checks['NMD'] is True: + # ── Graph: record NMD validation pass ── + nmd_pass_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_nmd, + outcome='Passed: normal mode displacement check', + ) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, nmd_pass_nid, EdgeType.output_of) if wrong_freq_message in self.output[label]['warnings']: self.output[label]['warnings'] = ''.join(self.output[label]['warnings'].split(wrong_freq_message)) elif not self.species_dict[label].is_ts and self.trsh_ess_jobs: @@ -2961,24 +2998,30 @@ def check_negative_freq(self, f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') # ── Graph: record TS freq validation failure ── - self.graph.add_decision_node( + freq_fail_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_validation_freq, 
criteria={'neg_freqs': [float(f) for f in neg_freqs], 'expected': 1}, outcome=f'Failed: {len(neg_freqs)} imaginary freqs, switching TS', ) - self.switch_ts(label=label) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, freq_fail_nid, EdgeType.output_of) + self.switch_ts(label=label, triggered_by_nid=freq_fail_nid) return False else: logger.info(f'TS {label} has exactly one imaginary frequency: {neg_freqs[0]}') # ── Graph: record TS freq validation pass ── - self.graph.add_decision_node( + freq_pass_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_validation_freq, criteria={'neg_freqs': [float(neg_freqs[0])]}, outcome='Passed: exactly 1 imaginary frequency', ) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, freq_pass_nid, EdgeType.output_of) self.output[label]['info'] += f'Imaginary frequency: {neg_freqs[0] if len(neg_freqs) == 1 else neg_freqs}; ' self.output[label]['job_types']['freq'] = True self.output[label]['paths']['freq'] = job.local_path_to_output_file @@ -3030,7 +3073,13 @@ def check_rxn_e0_by_spc(self, label: str): if rxn.ts_species.ts_checks['E0'] is False: logger.info(f'TS {rxn.ts_species.label} of reaction {rxn.label} did not pass the E0 check.\n' f'Searching for a better TS conformer...\n') - self.switch_ts(rxn.ts_label) + # ── Graph: record E0 validation failure ── + e0_fail_nid = self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_validation_e0, + outcome=f'Failed: TS E0 not above both wells for {rxn.label}', + ) + self.switch_ts(rxn.ts_label, triggered_by_nid=e0_fail_nid) if self.species_dict[rxn.ts_label].ts_guesses_exhausted \ or self.species_dict[rxn.ts_label].chosen_ts is None: logger.warning(f'Could not find a valid TS conformer for {rxn.ts_label} ' @@ -3039,20 +3088,41 @@ def check_rxn_e0_by_spc(self, label: str): 
# Restore E0 failure flag — switch_ts resets ts_checks via populate_ts_checks(). # check_all_done reads this to avoid overwriting convergence back to True. self.species_dict[rxn.ts_label].ts_checks['E0'] = False + elif rxn.ts_species.ts_checks['E0'] is True: + # ── Graph: record E0 validation pass ── + self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_validation_e0, + outcome=f'Passed: TS E0 above both wells for {rxn.label}', + ) + # Also record e_elect check if it ran as a fallback. + if rxn.ts_species.ts_checks['e_elect'] is True: + self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_validation_e_elect, + outcome=f'Passed: TS e_elect above both wells for {rxn.label}', + ) + elif rxn.ts_species.ts_checks['e_elect'] is False: + self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_validation_e_elect, + outcome=f'Warning: TS e_elect not above both wells for {rxn.label}', + ) - def switch_ts(self, label: str): + def switch_ts(self, label: str, triggered_by_nid: Optional[str] = None): """ Try the next optimized TS guess in line if a previous TS guess was found to be wrong. Args: label (str): The TS species label. + triggered_by_nid (str, optional): Node ID of the validation decision that triggered this switch. """ logger.info(f'Switching a TS guess for {label}...') old_chosen = self.species_dict[label].chosen_ts self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess. 
new_chosen = self.species_dict[label].chosen_ts # ── Graph: record TS switch decision ── - self.graph.add_decision_node( + switch_nid = self.graph.add_decision_node( label=label, decision_kind=DecisionKind.ts_switch, criteria={'old_chosen': old_chosen, 'new_chosen': new_chosen, @@ -3060,6 +3130,8 @@ def switch_ts(self, label: str): outcome=f'Switched from TSG #{old_chosen} to #{new_chosen}' if new_chosen is not None else 'All TS guesses exhausted', ) + if triggered_by_nid is not None: + self.graph.add_edge(triggered_by_nid, switch_nid, EdgeType.triggered_by) self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path): @@ -3239,17 +3311,25 @@ def check_irc_species(self, label: str): if len(self.output[ts_label]['paths']['irc']) == 2: irc_species_labels = self.species_dict[ts_label].irc_label.split() if all(self.output[irc_label]['paths']['geo'] for irc_label in irc_species_labels): + rxn = self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None) check_irc_species_and_rxn( xyz_1=self.output[irc_species_labels[0]]['paths']['geo'], xyz_2=self.output[irc_species_labels[1]]['paths']['geo'], - rxn=self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None), + rxn=rxn, ) - # ── Graph: record IRC validation decision ── + # ── Graph: record IRC validation decision with actual outcome ── + irc_result = rxn.ts_species.ts_checks['IRC'] if rxn is not None else None + if irc_result is True: + irc_outcome = 'Passed: IRC endpoints match expected R/P' + elif irc_result is False: + irc_outcome = 'Failed: IRC endpoints do not match expected R/P' + else: + irc_outcome = 'Inconclusive: could not determine IRC match' self.graph.add_decision_node( label=ts_label, decision_kind=DecisionKind.ts_validation_irc, - criteria={'irc_species': irc_species_labels}, - outcome='IRC validation completed', + criteria={'irc_species': 
irc_species_labels, 'result': irc_result}, + outcome=irc_outcome, ) def check_scan_job(self, @@ -3512,7 +3592,18 @@ def check_all_done(self, label: str): logger.debug(f'Species {label} did not converge due to missing output paths.') all_converged = False if label in self.output and all_converged: + already_converged = self.output[label]['convergence'] self.output[label]['convergence'] = True + # ── Graph: record convergence gate (only once per species) ── + if not already_converged: + conv_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.convergence_confirmed, + outcome='All required calculations converged', + ) + spc_nid = self.graph.find_species_node(label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, conv_nid, EdgeType.output_of) if self.species_dict[label].is_ts: self.species_dict[label].make_ts_report() logger.info(self.species_dict[label].ts_report + '\n') From 8cc3172982ffeec65508f200bb5636c837f99956 Mon Sep 17 00:00:00 2001 From: Calvin Pieters Date: Mon, 13 Apr 2026 20:31:28 +0300 Subject: [PATCH 14/14] update --- arc/scheduler.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 14d29eecaf..a369a2e7f5 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -307,6 +307,7 @@ def __init__(self, self.provenance_path = os.path.join(self.project_directory, 'output', 'provenance.yml') self.graph = ProvenanceGraph(project=self.project, run_id=self.provenance['run_id']) self.graph_path = os.path.join(self.project_directory, 'output', 'provenance_graph.yml') + self._pending_ts_selection_nid: Dict[str, str] = dict() # label → selection decision node ID self.species_dict, self.rxn_dict = dict(), dict() for species in self.species_list: @@ -1140,6 +1141,10 @@ def run_job(self, if parent_node_id is not None: edge_type = EdgeType.fine_of if provenance_reason == 'fine_opt' else EdgeType.retried_as self.graph.add_edge(parent_node_id, calc_node_id, 
edge_type) + # Wire selection decision → opt job for TS species. + if job_type in ('opt', 'composite') and provenance_label in self._pending_ts_selection_nid: + sel_nid = self._pending_ts_selection_nid.pop(provenance_label) + self.graph.add_edge(sel_nid, calc_node_id, EdgeType.triggered_by) job.execute() self.save_restart_dict() @@ -2610,10 +2615,8 @@ def determine_most_likely_ts_conformer(self, label: str): conf_opt_nid = self.graph.find_calc_node(label, f'conf_opt{selected_i}') if conf_opt_nid is not None: self.graph.add_edge(conf_opt_nid, sel_nid, EdgeType.selected_by) - # Connect: selection → species (so subsequent opt flows from it). - spc_nid = self.graph.find_species_node(label) - if spc_nid is not None: - self.graph.add_edge(sel_nid, spc_nid, EdgeType.output_of) + # Stash the selection node so the next opt job can link back to it. + self._pending_ts_selection_nid[label] = sel_nid if tsg.success and tsg.energy is not None: # guess method and ts_level opt were both successful tsg.energy -= e_min im_freqs = f', imaginary frequencies {tsg.imaginary_freqs}' if tsg.imaginary_freqs is not None else '' @@ -3132,6 +3135,9 @@ def switch_ts(self, label: str, triggered_by_nid: Optional[str] = None): ) if triggered_by_nid is not None: self.graph.add_edge(triggered_by_nid, switch_nid, EdgeType.triggered_by) + # Connect switch → the new selection decision (if a new guess was found). + if label in self._pending_ts_selection_nid: + self.graph.add_edge(switch_nid, self._pending_ts_selection_nid[label], EdgeType.triggered_by) self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path):