diff --git a/.gitignore b/.gitignore index 9132dc1f86..90f8505630 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,10 @@ build/* # AI Agent files AGENTS.md +CLAUDE.md + +# Provenance related +provenance.yml +provenance_graph.yml +provenance.svg +*.dot diff --git a/arc/job/pipe/pipe_coordinator.py b/arc/job/pipe/pipe_coordinator.py index b5a0d874e8..3c9631758c 100644 --- a/arc/job/pipe/pipe_coordinator.py +++ b/arc/job/pipe/pipe_coordinator.py @@ -274,13 +274,16 @@ def ingest_pipe_results(self, pipe: PipeRun) -> None: if state.status == TaskState.COMPLETED.value: ingest_completed_task(pipe.run_id, pipe.pipe_root, spec, state, self.sched.species_dict, self.sched.output) + self._update_graph_for_pipe_task(spec, status='done') elif state.status == TaskState.FAILED_ESS.value: self._eject_to_scheduler(pipe, spec, state) + self._update_graph_for_pipe_task(spec, status='errored') ejected_count += 1 elif state.status == TaskState.FAILED_TERMINAL.value: logger.error(f'Pipe run {pipe.run_id}, task {spec.task_id}: ' f'failed terminally (failure_class={state.failure_class}). ' f'Manual troubleshooting required.') + self._update_graph_for_pipe_task(spec, status='errored') elif state.status == TaskState.CANCELLED.value: logger.warning(f'Pipe run {pipe.run_id}, task {spec.task_id}: ' f'was cancelled.') @@ -290,6 +293,24 @@ def ingest_pipe_results(self, pipe: PipeRun) -> None: else: self._post_ingest_pipe_run(pipe) + def _update_graph_for_pipe_task(self, spec: TaskSpec, status: str) -> None: + """Update the provenance graph calc node for a completed/failed pipe task.""" + graph = getattr(self.sched, 'graph', None) + if graph is None: + return + label = spec.owner_key + meta = spec.ingestion_metadata or {} + job_type = TASK_FAMILY_TO_JOB_TYPE.get(spec.task_family, spec.task_family) + # Build the job_name the scheduler would have used for this task. + conf_idx = meta.get('conformer_index') + if conf_idx is not None: + job_name = f'{job_type}_{conf_idx}' + else: + job_name = spec.task_id # fallback to pipe task_id + calc_nid = graph.find_calc_node(label, job_name) + if calc_nid is not None: + graph.update_node(calc_nid, status=status) + def _post_ingest_pipe_run(self, pipe: PipeRun) -> None: """ Trigger family-specific post-processing after all tasks in a pipe run diff --git a/arc/plotter.py b/arc/plotter.py index d0f6938e84..c25d84549c 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -2,6 +2,7 @@ A module for plotting and saving output files such as RMG libraries. """ +import datetime import matplotlib # Force matplotlib to not use any Xwindows backend. # This must be called before pylab, matplotlib.pyplot, or matplotlib.backends is imported. 
@@ -12,10 +13,15 @@ import numpy as np import os import shutil +import textwrap from matplotlib.backends.backend_pdf import PdfPages from mpl_toolkits.mplot3d import Axes3D from typing import List, Optional, Tuple, Union +try: + import graphviz +except ImportError: + graphviz = None import py3Dmol as p3D from rdkit import Chem @@ -54,6 +60,360 @@ logger = get_logger() +def _sanitize_graphviz_id(value: str) -> str: + """Return a Graphviz-safe identifier.""" + return ''.join(ch if ch.isalnum() else '_' for ch in value) + + +def _wrap_graph_label(text: str, width: int = 24) -> str: + """Wrap long labels so graph nodes stay readable, preserving intentional newlines.""" + if not text: + return '' + return '\n'.join(line for part in str(text).split('\n') + for line in (textwrap.wrap(part, width=width) or [''])) + + +def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz.Digraph': + """ + Render a :class:`ProvenanceGraph` as a Graphviz directed graph. + + Node styling by type: + - **species**: box / aliceblue + - **calculation**: box / color by status (honeydew=done, mistyrose=errored, white=pending) + - **data**: note / cornsilk + - **decision**: diamond / color by kind (lavender, moccasin, mistyrose) + + Edge styling by type: + - ``selected_by``: solid green + - ``rejected_by``: dashed red + - ``troubleshot_by``: dashed orange + - ``retried_as`` / ``fine_of``: dotted gray + - others: solid black + + Args: + prov_graph: A :class:`ProvenanceGraph` instance. + run_label (str): Label for the root run node. + + Returns: + graphviz.Digraph: The rendered graph object. + """ + if graphviz is None: + raise ImportError('The graphviz Python package is required for render_provenance_graph(). ' + 'Install it with: conda install -c conda-forge python-graphviz') + gv = graphviz.Digraph( + name='arc_provenance', + comment=f'ARC provenance for {run_label}', + graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'}, + node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'}, + edge_attr={'fontname': 'Helvetica'}, + ) + + # Node styling lookup + _calc_colors = {'done': 'honeydew', 'errored': 'mistyrose', 'pending': 'white'} + _decision_colors = { + 'ts_guess_selection': 'lavender', + 'ts_guess_selection_failed': 'mistyrose', + 'job_troubleshooting': 'moccasin', + 'conformer_selection': 'lavender', + 'ts_guess_clustering': 'lavender', + 'ts_method_spawning': 'lavender', + 'ts_validation_freq': 'lightyellow', + 'ts_validation_nmd': 'lightyellow', + 'ts_validation_irc': 'lightyellow', + 'ts_switch': 'mistyrose', + } + + # Edge styling lookup + _edge_styles = { + 'selected_by': {'color': 'green3', 'style': 'solid'}, + 'rejected_by': {'color': 'red', 'style': 'dashed'}, + 'troubleshot_by': {'color': 'orange', 'style': 'dashed'}, + 'triggered_by': {'color': 'gray40', 'style': 'solid'}, + 'retried_as': {'color': 'gray60', 'style': 'dotted'}, + 'fine_of': {'color': 'gray60', 'style': 'dotted'}, + 'spawned_by': {'color': 'blue', 'style': 'solid'}, + } + + # ── Identify conf_opt batches to collapse ────────────────────────────── + # Group conf_opt calc nodes by species label for batch summarization. 
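+    # Collapsing matters because a single species can spawn dozens of conformer
+    # optimizations; one summary node per species keeps the diagram legible while
+    # still reporting the done/errored/pending breakdown computed below.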
+ _COLLAPSE_THRESHOLD = 5 # only collapse if more than this many conf_opt jobs + conf_opt_groups = {} # label -> list of node_ids + conf_opt_collapsed = set() # node_ids that will be replaced by a summary + for node in prov_graph.nodes.values(): + if (node.node_type == 'calculation' + and getattr(node, 'job_type', '') == 'conf_opt'): + conf_opt_groups.setdefault(node.label, []).append(node.node_id) + batch_summary_ids = {} # label -> summary_node_graphviz_id + for label, nids in conf_opt_groups.items(): + if len(nids) > _COLLAPSE_THRESHOLD: + conf_opt_collapsed.update(nids) + statuses = [getattr(prov_graph.get_node(n), 'status', 'pending') or 'pending' for n in nids] + done = statuses.count('done') + errored = statuses.count('errored') + pending = len(statuses) - done - errored + parts = [] + if done: + parts.append(f'{done} done') + if errored: + parts.append(f'{errored} errored') + if pending: + parts.append(f'{pending} pending') + summary_id = _sanitize_graphviz_id(f'batch_conf_opt_{label}') + batch_summary_ids[label] = summary_id + gv.node(summary_id, + _wrap_graph_label(f'conf_opt batch\n{len(nids)} jobs\n{", ".join(parts)}', width=28), + shape='box3d', fillcolor='lightyellow', style='filled') + + # ── Render individual nodes ────────────────────────────────────────── + for node in prov_graph.nodes.values(): + if node.node_id in conf_opt_collapsed: + continue # replaced by batch summary + nid = _sanitize_graphviz_id(node.node_id) + ntype = node.node_type + + if ntype == 'species': + lbl = node.label or node.node_id + is_ts = (node.metadata or {}).get('is_ts', False) + if is_ts: + lbl += '\nTS' + gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor='aliceblue') + + elif ntype == 'calculation': + job_type = getattr(node, 'job_type', '') or '' + job_name = getattr(node, 'job_name', '') or '' + lbl = f'{job_type}\n{job_name}' + if getattr(node, 'job_adapter', None): + lbl += f'\n{node.job_adapter}' + status = getattr(node, 'status', 'pending') or 'pending' + fillcolor = _calc_colors.get(status, 'white') + gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor=fillcolor) + + elif ntype == 'data': + dk = getattr(node, 'data_kind', '') or '' + val = getattr(node, 'value', None) + lbl = dk + if val is not None and not isinstance(val, (list, dict)): + lbl += f'\n{val}' + meta = getattr(node, 'metadata', None) or {} + if 'n_imaginary' in meta: + lbl += f'\n({meta["n_imaginary"]} imag)' + gv.node(nid, _wrap_graph_label(lbl), shape='note', fillcolor='cornsilk') + + elif ntype == 'decision': + dk = getattr(node, 'decision_kind', '') or '' + outcome = getattr(node, 'outcome', '') or '' + lbl = dk.replace('_', ' ') + if outcome: + lbl += f'\n{outcome}' + fillcolor = _decision_colors.get(dk, 'lavender') + gv.node(nid, _wrap_graph_label(lbl, width=28), shape='diamond', fillcolor=fillcolor) + + else: + gv.node(nid, _wrap_graph_label(node.node_id)) + + # ── Render edges ───────────────────────────────────────────────────── + # Track which batch summaries have been connected to avoid duplicate edges. + batch_edges_added = set() + for edge in prov_graph.edges: + src_collapsed = edge.source_id in conf_opt_collapsed + tgt_collapsed = edge.target_id in conf_opt_collapsed + # Redirect edges involving collapsed conf_opt nodes to the batch summary. 
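+            # e.g. an edge whose source is a collapsed conf_opt calc node is redrawn
+            # from that species' 'batch_conf_opt_<label>' summary node instead.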
+ if src_collapsed: + src_label = prov_graph.get_node(edge.source_id).label if prov_graph.get_node(edge.source_id) else None + src = batch_summary_ids.get(src_label, _sanitize_graphviz_id(edge.source_id)) + else: + src = _sanitize_graphviz_id(edge.source_id) + if tgt_collapsed: + tgt_label = prov_graph.get_node(edge.target_id).label if prov_graph.get_node(edge.target_id) else None + tgt = batch_summary_ids.get(tgt_label, _sanitize_graphviz_id(edge.target_id)) + else: + tgt = _sanitize_graphviz_id(edge.target_id) + # Deduplicate edges to/from batch summaries. + edge_key = (src, tgt, edge.edge_type) + if (src_collapsed or tgt_collapsed) and edge_key in batch_edges_added: + continue + batch_edges_added.add(edge_key) + etype = edge.edge_type + style_attrs = _edge_styles.get(etype, {}) + # Only show labels on semantically interesting edges (not belongs_to, input_of, output_of). + label = etype.replace('_', ' ') if etype not in ('belongs_to', 'input_of', 'output_of') else '' + gv.edge(src, tgt, label=label, **style_attrs) + + return gv + + +def save_provenance_artifacts(project_directory: str, + provenance: dict, + graph=None, + ) -> dict: + """ + Save provenance YAML and render Graphviz artifacts for an ARC run. + + When a ``graph`` (:class:`ProvenanceGraph`) is provided, the Graphviz + visualization is built from the graph's typed nodes and edges rather + than the flat event list, producing richer diagrams. + + Args: + project_directory (str): The ARC project directory. + provenance (dict): A provenance dictionary with an ``events`` list. + graph: Optional ProvenanceGraph instance for graph-based rendering. + + Returns: + dict: Paths to generated artifacts. + """ + output_directory = os.path.join(project_directory, 'output') + os.makedirs(output_directory, exist_ok=True) + yml_path = os.path.join(output_directory, 'provenance.yml') + dot_path = os.path.join(output_directory, 'provenance.dot') + svg_path = os.path.join(output_directory, 'provenance.svg') + + run_label = provenance.get('project', 'ARC run') + if graphviz is None: + logger.warning('The graphviz Python package is not available, so ARC will only save provenance.yml.') + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) + return {'yml': yml_path, 'dot': None, 'svg': None} + + # Prefer graph-based rendering when a ProvenanceGraph is available. + if graph is not None and len(graph) > 0: + gv_graph = render_provenance_graph(graph, run_label=run_label) + with open(dot_path, 'w') as f: + f.write(gv_graph.source) + try: + svg_data = gv_graph.pipe(format='svg') + except (graphviz.ExecutableNotFound, graphviz.CalledProcessError): + logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.') + else: + with open(svg_path, 'wb') as f: + f.write(svg_data) + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) + return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if os.path.isfile(svg_path) else None} + + # Fallback: event-based rendering (legacy path). 
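+    # This legacy path reconstructs an approximate diagram from the flat 'events'
+    # list (species_initialized / job_started / job_finished / decision events)
+    # when no typed ProvenanceGraph object was supplied.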
+ graph = graphviz.Digraph( + name='arc_provenance', + comment=f'ARC provenance for {run_label}', + graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'}, + node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'}, + edge_attr={'fontname': 'Helvetica'}, + ) + run_node_id = _sanitize_graphviz_id(f"run_{provenance.get('run_id', run_label)}") + run_header = provenance.get('started_at', '') + run_footer = provenance.get('ended_at', '') + run_text = f'{run_label}' + if run_header: + run_text += f'\nstart: {run_header}' + if run_footer: + run_text += f'\nend: {run_footer}' + graph.node(run_node_id, _wrap_graph_label(run_text, width=32), shape='oval', fillcolor='lightgoldenrod1') + + species_nodes = dict() + job_nodes = dict() + # Track the most recent decision node (troubleshoot / TS selection) per label, + # so that follow-up jobs spawned by that decision connect from the diamond. + last_decision_by_label = dict() + + for event in provenance.get('events', list()): + event_type = event.get('event_type', '') + label = event.get('label') + if label and label not in species_nodes: + species_node_id = _sanitize_graphviz_id(f'species_{label}') + species_text = label + if event.get('is_ts'): + species_text += '\nTS' + graph.node(species_node_id, _wrap_graph_label(species_text), fillcolor='aliceblue') + graph.edge(run_node_id, species_node_id) + species_nodes[label] = species_node_id + + if event_type == 'job_started': + job_key = event.get('job_key', event.get('job_name', 'job')) + job_node_id = _sanitize_graphviz_id(f'job_{job_key}') + job_text = f"{event.get('job_type', 'job')}\n{event.get('job_name', job_key)}" + if event.get('job_adapter'): + job_text += f"\n{event['job_adapter']}" + if event.get('level'): + job_text += f"\n{event['level']}" + graph.node(job_node_id, _wrap_graph_label(job_text), fillcolor='white') + + # Determine the source node for this job's incoming edge. + parent_job = event.get('provenance_parent_job') + reason = event.get('provenance_reason', '') + if parent_job and label in last_decision_by_label: + # A decision (troubleshoot / TS selection) preceded this job — connect from it. + source_node_id = last_decision_by_label.pop(label) + elif parent_job: + # Rerun or other child job — connect from the parent job node. + parent_key = f'{label}:{parent_job}' + source_node_id = job_nodes.get(parent_key, species_nodes.get(label, run_node_id)) + else: + # Normal first-launch job — connect from the species node. 
+ source_node_id = species_nodes.get(label, run_node_id) + graph.edge(source_node_id, job_node_id, label=reason) + job_nodes[job_key] = job_node_id + + elif event_type == 'job_finished': + job_key = event.get('job_key') + if job_key in job_nodes: + status = event.get('status', 'unknown') + fillcolor = {'done': 'honeydew', 'errored': 'mistyrose'}.get(status, 'lightyellow') + graph.node(job_nodes[job_key], fillcolor=fillcolor) + + result_node_id = _sanitize_graphviz_id( + f"result_{event.get('event_id', len(job_nodes))}_{job_key}" + ) + result_text = f"{status}" + if event.get('run_time'): + result_text += f"\n{event['run_time']}" + if event.get('keywords'): + result_text += f"\n{', '.join(event['keywords'])}" + graph.node(result_node_id, _wrap_graph_label(result_text), shape='note', fillcolor='cornsilk') + graph.edge(job_nodes[job_key], result_node_id) + + elif event_type in ('ts_guess_selected', 'ts_guess_selection_failed', 'job_troubleshooting'): + decision_node_id = _sanitize_graphviz_id(f"decision_{event.get('event_id', 0)}") + if event_type == 'ts_guess_selected': + decision_text = f"Select TS guess {event.get('selected_index')}" + if event.get('method'): + decision_text += f"\n{event['method']}" + fillcolor = 'lavender' + elif event_type == 'ts_guess_selection_failed': + decision_text = 'TS guess selection\nfailed' + fillcolor = 'mistyrose' + else: + decision_text = f"Troubleshoot {event.get('job_name', '')}" + if event.get('methods'): + decision_text += f"\n{', '.join(event['methods'])}" + fillcolor = 'moccasin' + graph.node(decision_node_id, _wrap_graph_label(decision_text), shape='diamond', fillcolor=fillcolor) + source_job_key = event.get('job_key') + source_node_id = job_nodes.get(source_job_key) if source_job_key else species_nodes.get(label) + if source_node_id is None and label is not None: + source_node_id = species_nodes.get(label) + if source_node_id is not None: + graph.edge(source_node_id, decision_node_id) + if label is not None: + last_decision_by_label[label] = decision_node_id + + elif event_type == 'species_initialized' and label in species_nodes: + continue + + with open(dot_path, 'w') as f: + f.write(graph.source) + + try: + svg_data = graph.pipe(format='svg') + except (graphviz.ExecutableNotFound, graphviz.CalledProcessError): + logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.') + else: + with open(svg_path, 'wb') as f: + f.write(svg_data) + + provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds') + save_yaml_file(path=yml_path, content=provenance) + return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if os.path.isfile(svg_path) else None} + + # *** Drawings species *** def draw_structure(xyz=None, species=None, project_directory=None, method='show_sticks', show_atom_indices=False): diff --git a/arc/plotter_test.py b/arc/plotter_test.py index ba6984dae4..dea852b275 100644 --- a/arc/plotter_test.py +++ b/arc/plotter_test.py @@ -9,6 +9,11 @@ import shutil import unittest +try: + import graphviz +except ImportError: + graphviz = None + import arc.plotter as plotter from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, safe_copy_file from arc.species.converter import str_to_xyz @@ -218,6 +223,139 @@ def test_save_irc_traj_animation(self): plotter.save_irc_traj_animation(irc_f_path, irc_r_path, out_path) self.assertTrue(os.path.isfile(out_path)) + def test_wrap_graph_label(self): + """Test that _wrap_graph_label preserves intentional newlines.""" + # Intentional newlines 
should be preserved, not collapsed. + result = plotter._wrap_graph_label("opt\nopt_a1\ngaussian\nwb97xd/def2tzvp", width=30) + lines = result.split('\n') + self.assertEqual(lines[0], 'opt') + self.assertEqual(lines[1], 'opt_a1') + self.assertEqual(lines[2], 'gaussian') + self.assertEqual(lines[3], 'wb97xd/def2tzvp') + # Long single lines should still be wrapped. + result = plotter._wrap_graph_label("this is a very long label that should be wrapped", width=20) + self.assertTrue(all(len(line) <= 20 for line in result.split('\n'))) + # Empty string returns empty. + self.assertEqual(plotter._wrap_graph_label(''), '') + + def test_save_provenance_artifacts(self): + """Test saving ARC provenance YAML / Graphviz artifacts.""" + project = 'arc_project_for_testing_delete_after_usage' + project_directory = os.path.join(ARC_PATH, 'Projects', project) + provenance = { + 'project': project, + 'run_id': 'run_1', + 'started_at': '2026-03-15T10:00:00', + 'ended_at': '2026-03-15T10:05:00', + 'events': [ + {'event_id': 1, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', + 'label': 'spc1'}, + {'event_id': 2, 'event_type': 'species_initialized', 'timestamp': '2026-03-15T10:00:00', + 'label': 'TS0', 'is_ts': True}, + {'event_id': 3, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:00:01', + 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'job_name': 'opt_a1', 'job_type': 'opt', + 'job_adapter': 'gaussian', 'level': 'b3lyp/6-31g(d)'}, + {'event_id': 4, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:00', + 'label': 'spc1', 'job_key': 'spc1:opt_a1', 'status': 'done', 'run_time': '0:01:00'}, + {'event_id': 5, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:01', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'job_name': 'freq_a2', 'job_type': 'freq', + 'job_adapter': 'gaussian', 'level': 'b3lyp/6-31g(d)'}, + {'event_id': 6, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:01:30', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'status': 'errored', + 'run_time': '0:00:30', 'keywords': ['memory']}, + {'event_id': 7, 'event_type': 'job_troubleshooting', 'timestamp': '2026-03-15T10:01:35', + 'label': 'spc1', 'job_key': 'spc1:freq_a2', 'job_name': 'freq_a2', 'job_type': 'freq', + 'methods': ['memory']}, + {'event_id': 8, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:01:40', + 'label': 'spc1', 'job_key': 'spc1:freq_a3', 'job_name': 'freq_a3', 'job_type': 'freq', + 'job_adapter': 'gaussian', 'provenance_parent_job': 'freq_a2', + 'provenance_reason': 'ess_troubleshoot'}, + {'event_id': 9, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:02:00', + 'label': 'spc1', 'job_key': 'spc1:freq_a3', 'status': 'done', 'run_time': '0:00:20'}, + {'event_id': 10, 'event_type': 'job_started', 'timestamp': '2026-03-15T10:02:01', + 'label': 'TS0', 'job_key': 'TS0:tsg0', 'job_name': 'tsg0', 'job_type': 'tsg', + 'job_adapter': 'autotst'}, + {'event_id': 11, 'event_type': 'job_finished', 'timestamp': '2026-03-15T10:03:00', + 'label': 'TS0', 'job_key': 'TS0:tsg0', 'status': 'done'}, + {'event_id': 12, 'event_type': 'ts_guess_selected', 'timestamp': '2026-03-15T10:03:01', + 'label': 'TS0', 'selected_index': 0, 'method': 'autotst', 'energy': -154.321}, + ], + } + paths = plotter.save_provenance_artifacts(project_directory=project_directory, provenance=provenance) + self.assertTrue(os.path.isfile(paths['yml'])) + if paths['dot'] is not None: + self.assertTrue(os.path.isfile(paths['dot'])) + with open(paths['dot'], 'r') as f: + dot = f.read() + # Species and job nodes 
are present. + self.assertIn('spc1', dot) + self.assertIn('opt_a1', dot) + self.assertIn('TS0', dot) + # Troubleshoot diamond and edge label rendered. + self.assertIn('Troubleshoot', dot) + self.assertIn('ess_troubleshoot', dot) + # TS guess selection diamond rendered. + self.assertIn('Select TS guess 0', dot) + self.assertIn('autotst', dot) + # Errored job node coloured correctly. + self.assertIn('mistyrose', dot) + # Normal jobs (opt_a1, freq_a2) connect from the species node, not from each other. + self.assertIn('species_spc1 -> job_spc1_opt_a1', dot) + self.assertIn('species_spc1 -> job_spc1_freq_a2', dot) + # Troubleshoot follow-up connects from the decision diamond, not the species node. + self.assertIn('decision_7 -> job_spc1_freq_a3', dot) + + def test_render_provenance_graph(self): + """Test Graphviz rendering from a ProvenanceGraph object.""" + from arc.provenance import (ProvenanceGraph, DecisionKind, DataKind, EdgeType) + g = ProvenanceGraph(project='render_test') + sid = g.add_species_node(label='ethanol') + cid = g.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', job_adapter='gaussian', + level='b3lyp/6-31g(d)', status='done') + did = g.add_data_node(label='ethanol', data_kind=DataKind.energy, value=-79.5) + dec = g.add_decision_node(label='ethanol', + decision_kind=DecisionKind.conformer_selection, + outcome='Selected conformer #0') + g.add_edge(sid, cid, EdgeType.input_of) + g.add_edge(cid, did, EdgeType.output_of) + g.add_edge(did, dec, EdgeType.selected_by) + + if graphviz is not None: + gv = plotter.render_provenance_graph(g, run_label='render_test') + dot_source = gv.source + self.assertIn('ethanol', dot_source) + self.assertIn('opt', dot_source) + self.assertIn('energy', dot_source) + self.assertIn('conformer selection', dot_source) + self.assertIn('honeydew', dot_source) # done calc + self.assertIn('cornsilk', dot_source) # data node + self.assertIn('diamond', dot_source) # decision node + self.assertIn('green3', dot_source) # selected_by edge + + def test_save_provenance_artifacts_with_graph(self): + """Test that save_provenance_artifacts prefers graph-based rendering when a graph is provided.""" + from arc.provenance import (ProvenanceGraph, DecisionKind, EdgeType) + project = 'arc_project_for_testing_delete_after_usage' + project_directory = os.path.join(ARC_PATH, 'Projects', project) + g = ProvenanceGraph(project=project) + sid = g.add_species_node(label='spc1') + cid = g.add_calculation_node(label='spc1', job_name='opt_a1', + job_type='opt', status='done') + g.add_edge(sid, cid, EdgeType.input_of) + provenance = {'project': project, 'events': []} + paths = plotter.save_provenance_artifacts( + project_directory=project_directory, + provenance=provenance, + graph=g, + ) + self.assertTrue(os.path.isfile(paths['yml'])) + if paths['dot'] is not None: + with open(paths['dot'], 'r') as f: + dot = f.read() + # Graph-based rendering uses node IDs like species_1 not event-based species_spc1. + self.assertIn('species_1', dot) + self.assertIn('honeydew', dot) @classmethod def tearDownClass(cls): diff --git a/arc/provenance/__init__.py b/arc/provenance/__init__.py new file mode 100644 index 0000000000..d6da045f38 --- /dev/null +++ b/arc/provenance/__init__.py @@ -0,0 +1,38 @@ +""" +ARC provenance subpackage — directed acyclic graph for computational provenance. + +Tracks the full chain of inputs, calculations, decisions, and outputs that +produce ARC's results. 
Inspired by AiiDA's DAG model but adapted for ARC's +branching decision trees (TS guess evaluation, conformer selection, +troubleshooting loops). + +Submodules: + - ``nodes``: Node types, edge types, and their data classes. + - ``graph``: ProvenanceGraph container with query and serialization. +""" + +from arc.provenance.graph import ProvenanceGraph +from arc.provenance.nodes import ( + CalculationNode, + DataKind, + DataNode, + DecisionKind, + DecisionNode, + EdgeType, + NodeType, + ProvenanceEdge, + ProvenanceNode, +) + +__all__ = [ + 'ProvenanceGraph', + 'ProvenanceNode', + 'CalculationNode', + 'DataNode', + 'DecisionNode', + 'ProvenanceEdge', + 'NodeType', + 'DataKind', + 'DecisionKind', + 'EdgeType', +] diff --git a/arc/provenance/graph.py b/arc/provenance/graph.py new file mode 100644 index 0000000000..4ef0b7e33a --- /dev/null +++ b/arc/provenance/graph.py @@ -0,0 +1,366 @@ +""" +ProvenanceGraph — a directed acyclic graph for tracking ARC computational provenance. + +The graph stores typed nodes (species, calculations, data artifacts, decisions) +connected by typed directed edges (input_of, selected_by, troubleshot_by, etc.). +It supports forward/backward traversal, flexible queries, and YAML serialization +via the project's standard ``save_yaml_file`` / ``read_yaml_file`` helpers. +""" + +import datetime +import re +from collections import deque +from typing import Any, Dict, List, Optional + +from arc.common import get_logger, read_yaml_file, save_yaml_file +from arc.provenance.nodes import ( + CalculationNode, + DataNode, + DecisionNode, + NodeType, + ProvenanceEdge, + ProvenanceNode, + _enum_val, +) + +logger = get_logger() + +SCHEMA_VERSION = 2 + + +class ProvenanceGraph(object): + """ + A directed acyclic graph for tracking computational provenance. + + Args: + project (str, optional): The ARC project name. + run_id (str, optional): Unique run identifier. + + Attributes: + nodes (Dict[str, ProvenanceNode]): Maps node_id to node. + edges (List[ProvenanceEdge]): All directed edges. + """ + + def __init__(self, + project: Optional[str] = None, + run_id: Optional[str] = None, + ): + self.project = project + self.run_id = run_id or ( + f'{project}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}' + if project else None + ) + self.nodes: Dict[str, ProvenanceNode] = {} + self.edges: List[ProvenanceEdge] = [] + self._counter: int = 0 + + # ── Node operations ────────────────────────────────────────────────────── + + def _next_id(self, prefix: str) -> str: + """Generate the next unique node ID with the given prefix.""" + self._counter += 1 + return f'{prefix}_{self._counter}' + + def add_node(self, node: ProvenanceNode) -> str: + """ + Add a node to the graph. + + Args: + node: The node to add. + + Returns: + str: The node's ID. + """ + if node.node_id in self.nodes: + logger.debug(f'Node {node.node_id!r} already exists in the provenance graph; skipping.') + return node.node_id + self.nodes[node.node_id] = node + return node.node_id + + def add_species_node(self, label: Optional[str] = None, is_ts: bool = False, + timestamp: Optional[str] = None) -> str: + """ + Convenience method to add a species node. + + Args: + label: Species label (optional). + is_ts: Whether this is a transition state. + timestamp: Optional ISO timestamp. + + Returns: + str: The new node's ID. 
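+
+        A minimal usage sketch (the label is illustrative)::
+
+            graph = ProvenanceGraph(project='demo')
+            sid = graph.add_species_node(label='H2O')  # first node -> 'species_1'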
+ """ + node_id = self._next_id('species') + metadata = {'is_ts': is_ts} if is_ts else None + node = ProvenanceNode(node_id=node_id, node_type=NodeType.species, + label=label, timestamp=timestamp, metadata=metadata) + self.add_node(node) + return node_id + + def add_calculation_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a calculation node. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('calc') + node = CalculationNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def add_data_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a data node. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('data') + node = DataNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def add_decision_node(self, label: Optional[str] = None, **kwargs) -> str: + """ + Convenience method to add a decision node. + + Returns: + str: The new node's ID. + """ + node_id = self._next_id('decision') + node = DecisionNode(node_id=node_id, label=label, **kwargs) + self.add_node(node) + return node_id + + def get_node(self, node_id: str) -> Optional[ProvenanceNode]: + """Return the node with the given ID, or None.""" + return self.nodes.get(node_id) + + def get_nodes_by_type(self, node_type: str, + label: Optional[str] = None) -> List[ProvenanceNode]: + """Return all nodes of the given type, optionally filtered by label.""" + results = [n for n in self.nodes.values() if n.node_type == _enum_val(node_type)] + if label is not None: + results = [n for n in results if n.label == label] + return results + + def get_nodes_by_label(self, label: str) -> List[ProvenanceNode]: + """Return all nodes associated with the given species label.""" + return [n for n in self.nodes.values() if n.label == label] + + def find_species_node(self, label: str) -> Optional[str]: + """Return the node_id of the species node for the given label, or None.""" + for n in self.nodes.values(): + if n.node_type == 'species' and n.label == label: + return n.node_id + return None + + def find_calc_node(self, label: str, job_name: str) -> Optional[str]: + """Return the node_id of a calculation node matching label and job_name, or None.""" + for n in self.nodes.values(): + if (n.node_type == 'calculation' + and n.label == label + and getattr(n, 'job_name', None) == job_name): + return n.node_id + return None + + def update_node(self, node_id: str, **attrs) -> bool: + """ + Update attributes on an existing node. + + Args: + node_id: The node to update. + **attrs: Attribute names and new values. + + Returns: + bool: True if the node was found and updated. + """ + node = self.nodes.get(node_id) + if node is None: + return False + for key, value in attrs.items(): + setattr(node, key, value) + return True + + # ── Edge operations ────────────────────────────────────────────────────── + + def add_edge(self, + source_id: str, + target_id: str, + edge_type: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> ProvenanceEdge: + """ + Add a directed edge between two nodes. + + Logs a warning if source or target node does not exist in the graph, + but still creates the edge (the node may be added later on restart). + + Args: + source_id: Source node ID. + target_id: Target node ID. + edge_type: One of :class:`EdgeType` values. + metadata: Optional extra data. + + Returns: + The created edge. 
+ """ + if source_id not in self.nodes: + logger.warning(f'Creating edge from non-existent source node {source_id!r}') + if target_id not in self.nodes: + logger.warning(f'Creating edge to non-existent target node {target_id!r}') + edge = ProvenanceEdge(source_id=source_id, target_id=target_id, + edge_type=edge_type, metadata=metadata) + self.edges.append(edge) + return edge + + def get_edges_from(self, node_id: str) -> List[ProvenanceEdge]: + """Return all edges originating from the given node.""" + return [e for e in self.edges if e.source_id == node_id] + + def get_edges_to(self, node_id: str) -> List[ProvenanceEdge]: + """Return all edges pointing to the given node.""" + return [e for e in self.edges if e.target_id == node_id] + + def get_edges_by_type(self, edge_type: str) -> List[ProvenanceEdge]: + """Return all edges of the given type.""" + return [e for e in self.edges if e.edge_type == _enum_val(edge_type)] + + # ── Traversal ──────────────────────────────────────────────────────────── + + def descendants(self, node_id: str) -> List[str]: + """ + Return all node IDs reachable forward from *node_id* (BFS). + + Does not include *node_id* itself. + """ + visited = set() + queue = deque() + for e in self.edges: + if e.source_id == node_id: + queue.append(e.target_id) + while queue: + nid = queue.popleft() + if nid in visited: + continue + visited.add(nid) + for e in self.edges: + if e.source_id == nid and e.target_id not in visited: + queue.append(e.target_id) + return list(visited) + + def ancestors(self, node_id: str) -> List[str]: + """ + Return all node IDs reachable backward from *node_id* (BFS). + + Does not include *node_id* itself. + """ + visited = set() + queue = deque() + for e in self.edges: + if e.target_id == node_id: + queue.append(e.source_id) + while queue: + nid = queue.popleft() + if nid in visited: + continue + visited.add(nid) + for e in self.edges: + if e.target_id == nid and e.source_id not in visited: + queue.append(e.source_id) + return list(visited) + + # ── Query ──────────────────────────────────────────────────────────────── + + def query(self, + node_type: Optional[str] = None, + label: Optional[str] = None, + decision_kind: Optional[str] = None, + data_kind: Optional[str] = None, + status: Optional[str] = None, + ) -> List[ProvenanceNode]: + """ + Flexible query over nodes with optional filters. + + All provided filters are ANDed together. + + Args: + node_type: Filter by NodeType value. + label: Filter by species label. + decision_kind: Filter DecisionNodes by DecisionKind value. + data_kind: Filter DataNodes by DataKind value. + status: Filter CalculationNodes by job status. + + Returns: + List of matching nodes. 
+ """ + results = list(self.nodes.values()) + if node_type is not None: + results = [n for n in results if n.node_type == _enum_val(node_type)] + if label is not None: + results = [n for n in results if n.label == label] + if decision_kind is not None: + results = [n for n in results + if getattr(n, 'decision_kind', None) == _enum_val(decision_kind)] + if data_kind is not None: + results = [n for n in results + if getattr(n, 'data_kind', None) == _enum_val(data_kind)] + if status is not None: + results = [n for n in results + if getattr(n, 'status', None) == status] + return results + + # ── Serialization ──────────────────────────────────────────────────────── + + def as_dict(self) -> Dict[str, Any]: + """Serialize the full graph to a dict for YAML output.""" + d: Dict[str, Any] = { + 'schema_version': SCHEMA_VERSION, + } + if self.project is not None: + d['project'] = self.project + if self.run_id is not None: + d['run_id'] = self.run_id + d['nodes'] = [node.as_dict() for node in self.nodes.values()] + d['edges'] = [edge.as_dict() for edge in self.edges] + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceGraph': + """Reconstruct a ProvenanceGraph from a dict (e.g. loaded from YAML).""" + obj = object.__new__(cls) + obj.project = d.get('project') + obj.run_id = d.get('run_id') + obj.nodes = {} + obj.edges = [] + obj._counter = 0 + for node_dict in d.get('nodes', []): + node = ProvenanceNode.from_dict(node_dict) + obj.nodes[node.node_id] = node + # Update counter to avoid ID collisions on restart. + match = re.search(r'_(\d+)$', node.node_id) + if match: + obj._counter = max(obj._counter, int(match.group(1))) + for edge_dict in d.get('edges', []): + obj.edges.append(ProvenanceEdge.from_dict(edge_dict)) + return obj + + def save(self, path: str) -> None: + """Persist the graph to a YAML file.""" + save_yaml_file(path=path, content=self.as_dict()) + + @classmethod + def load(cls, path: str) -> 'ProvenanceGraph': + """Load a ProvenanceGraph from a YAML file.""" + data = read_yaml_file(path) + if not isinstance(data, dict): + raise ValueError(f'Expected a dict in {path}, got {type(data).__name__}') + return cls.from_dict(data) + + def __len__(self) -> int: + return len(self.nodes) + + def __repr__(self) -> str: + return (f'ProvenanceGraph(project={self.project!r}, ' + f'nodes={len(self.nodes)}, edges={len(self.edges)})') diff --git a/arc/provenance/nodes.py b/arc/provenance/nodes.py new file mode 100644 index 0000000000..bebe43a8cc --- /dev/null +++ b/arc/provenance/nodes.py @@ -0,0 +1,386 @@ +""" +Provenance node and edge types for the ARC provenance DAG. + +Defines the fundamental building blocks of the provenance graph: + +- **Enums**: ``NodeType``, ``DataKind``, ``DecisionKind``, ``EdgeType`` + classify nodes and edges. +- **Node classes**: ``ProvenanceNode`` (base), ``CalculationNode``, + ``DataNode``, ``DecisionNode`` represent vertices in the DAG. +- **Edge class**: ``ProvenanceEdge`` represents a directed, typed + relationship between two nodes. + +All classes follow the ``as_dict()`` / ``from_dict()`` serialization +pattern used throughout ARC (see ``arc.job.pipe.pipe_state``). 
+""" + +import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + + +def _enum_val(val): + """Extract the plain string value from a str-Enum or pass through a string.""" + return val.value if isinstance(val, Enum) else val + + +# ── Enums ──────────────────────────────────────────────────────────────────── + + +class NodeType(str, Enum): + """Types of nodes in the provenance DAG.""" + species = 'species' + data = 'data' + calculation = 'calculation' + decision = 'decision' + + +class DataKind(str, Enum): + """Sub-classification for DataNode content.""" + geometry = 'geometry' + energy = 'energy' + frequencies = 'frequencies' + imaginary_freq = 'imaginary_freq' + irc_path = 'irc_path' + conformer_set = 'conformer_set' + ts_guess_set = 'ts_guess_set' + + +class DecisionKind(str, Enum): + """Sub-classification for DecisionNode decisions.""" + conformer_selection = 'conformer_selection' + ts_guess_clustering = 'ts_guess_clustering' + ts_guess_selection = 'ts_guess_selection' + ts_guess_selection_failed = 'ts_guess_selection_failed' + ts_validation_freq = 'ts_validation_freq' + ts_validation_nmd = 'ts_validation_nmd' + ts_validation_irc = 'ts_validation_irc' + ts_switch = 'ts_switch' + job_troubleshooting = 'job_troubleshooting' + ts_method_spawning = 'ts_method_spawning' + + +class EdgeType(str, Enum): + """Types of directed edges in the provenance DAG.""" + input_of = 'input_of' + output_of = 'output_of' + triggered_by = 'triggered_by' + selected_by = 'selected_by' + rejected_by = 'rejected_by' + spawned_by = 'spawned_by' + troubleshot_by = 'troubleshot_by' + belongs_to = 'belongs_to' + retried_as = 'retried_as' + fine_of = 'fine_of' + + +# ── Node classes ───────────────────────────────────────────────────────────── + + +class ProvenanceNode(object): + """ + Base class for a node in the provenance DAG. + + Args: + node_id (str): Unique identifier (e.g. ``'species_0'``, ``'calc_17'``). + node_type (str): One of :class:`NodeType` values. + label (str, optional): Species label this node is associated with. + timestamp (str, optional): ISO 8601 creation timestamp. + Auto-generated if not provided. + metadata (dict, optional): Arbitrary extra key-value data. + """ + + def __init__(self, + node_id: str, + node_type: str, + label: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + self.node_id = node_id + self.node_type = _enum_val(node_type) + self.label = label + self.timestamp = timestamp or datetime.datetime.now().isoformat(timespec='seconds') + self.metadata = metadata + + def as_dict(self) -> Dict[str, Any]: + """Serialize to a sparse dict (None and empty values omitted).""" + d: Dict[str, Any] = { + 'node_id': self.node_id, + 'node_type': self.node_type, + } + if self.label is not None: + d['label'] = self.label + if self.timestamp is not None: + d['timestamp'] = self.timestamp + if self.metadata: + d['metadata'] = self.metadata + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceNode': + """Reconstruct a ProvenanceNode (or appropriate subclass) from a dict.""" + node_type = d.get('node_type', '') + # Dispatch to the correct subclass based on node_type. + # Keys use plain strings so YAML-deserialized values match. 
+ subclass_map = { + 'calculation': CalculationNode, + 'data': DataNode, + 'decision': DecisionNode, + } + target_cls = subclass_map.get(node_type, cls) + if target_cls is not cls: + return target_cls.from_dict(d) + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', '') + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + return obj + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.node_id!r}, type={self.node_type!r}, label={self.label!r})' + + +class CalculationNode(ProvenanceNode): + """ + A computational job node (opt, freq, sp, scan, tsg, irc, composite, etc.). + + Args: + node_id (str): Unique identifier. + label (str, optional): Species label. + job_name (str, optional): ARC job name (e.g. ``'opt_a1'``). + job_type (str, optional): Job type (e.g. ``'opt'``, ``'freq'``). + job_adapter (str, optional): ESS adapter (e.g. ``'gaussian'``). + level (str, optional): Level of theory string. + status (str, optional): Job outcome: ``'pending'``, ``'done'``, ``'errored'``. + run_time (str, optional): Wall-clock duration string. + conformer (int, optional): Conformer index, if applicable. + tsg (int, optional): TS guess index, if applicable. + ess_trsh_methods (list, optional): Troubleshooting methods applied. + timestamp (str, optional): ISO 8601 creation timestamp. + metadata (dict, optional): Extra data. + """ + + def __init__(self, + node_id: str, + label: Optional[str] = None, + job_name: Optional[str] = None, + job_type: Optional[str] = None, + job_adapter: Optional[str] = None, + level: Optional[str] = None, + status: Optional[str] = None, + run_time: Optional[str] = None, + conformer: Optional[int] = None, + tsg: Optional[int] = None, + ess_trsh_methods: Optional[List[str]] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__(node_id=node_id, node_type=NodeType.calculation, + label=label, timestamp=timestamp, metadata=metadata) + self.job_name = job_name + self.job_type = job_type + self.job_adapter = job_adapter + self.level = level + self.status = status + self.run_time = run_time + self.conformer = conformer + self.tsg = tsg + self.ess_trsh_methods = ess_trsh_methods + + def as_dict(self) -> Dict[str, Any]: + d = super().as_dict() + if self.job_name is not None: + d['job_name'] = self.job_name + if self.job_type is not None: + d['job_type'] = self.job_type + if self.job_adapter is not None: + d['job_adapter'] = self.job_adapter + if self.level is not None: + d['level'] = self.level + if self.status is not None: + d['status'] = self.status + if self.run_time is not None: + d['run_time'] = self.run_time + if self.conformer is not None: + d['conformer'] = self.conformer + if self.tsg is not None: + d['tsg'] = self.tsg + if self.ess_trsh_methods: + d['ess_trsh_methods'] = list(self.ess_trsh_methods) + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'CalculationNode': + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', NodeType.calculation) + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + obj.job_name = d.get('job_name') + obj.job_type = d.get('job_type') + obj.job_adapter = d.get('job_adapter') + obj.level = d.get('level') + obj.status = d.get('status') + obj.run_time = d.get('run_time') + obj.conformer = d.get('conformer') + obj.tsg = d.get('tsg') + obj.ess_trsh_methods = d.get('ess_trsh_methods') + 
return obj + + +class DataNode(ProvenanceNode): + """ + A data artifact node (geometry, energy, frequencies, etc.). + + Args: + node_id (str): Unique identifier. + label (str, optional): Species label. + data_kind (str, optional): One of :class:`DataKind` values. + value: The scalar or small data payload (energy float, freq list, etc.). + source_path (str, optional): Path to the file containing this data. + timestamp (str, optional): ISO 8601 creation timestamp. + metadata (dict, optional): Extra data. + """ + + def __init__(self, + node_id: str, + label: Optional[str] = None, + data_kind: Optional[str] = None, + value: Any = None, + source_path: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__(node_id=node_id, node_type=NodeType.data, + label=label, timestamp=timestamp, metadata=metadata) + self.data_kind = _enum_val(data_kind) if data_kind is not None else None + self.value = value + self.source_path = source_path + + def as_dict(self) -> Dict[str, Any]: + d = super().as_dict() + if self.data_kind is not None: + d['data_kind'] = self.data_kind + if self.value is not None: + d['value'] = self.value + if self.source_path is not None: + d['source_path'] = self.source_path + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'DataNode': + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', NodeType.data) + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + obj.data_kind = d.get('data_kind') + obj.value = d.get('value') + obj.source_path = d.get('source_path') + return obj + + +class DecisionNode(ProvenanceNode): + """ + An algorithmic decision point (conformer selection, TS validation, etc.). + + Args: + node_id (str): Unique identifier. + label (str, optional): Species label. + decision_kind (str, optional): One of :class:`DecisionKind` values. + criteria (dict, optional): The selection/rejection criteria applied. + outcome (str, optional): Human-readable summary of the decision result. + timestamp (str, optional): ISO 8601 creation timestamp. + metadata (dict, optional): Extra data. 
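+
+        A minimal sketch mirroring the unit tests::
+
+            DecisionNode(node_id='decision_1', label='TS0',
+                         decision_kind=DecisionKind.job_troubleshooting,
+                         criteria={'error_keywords': ['SCF'], 'applied': 'SCF=QC'},
+                         outcome='Retrying with SCF=QC')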
+ """ + + def __init__(self, + node_id: str, + label: Optional[str] = None, + decision_kind: Optional[str] = None, + criteria: Optional[Dict[str, Any]] = None, + outcome: Optional[str] = None, + timestamp: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__(node_id=node_id, node_type=NodeType.decision, + label=label, timestamp=timestamp, metadata=metadata) + self.decision_kind = _enum_val(decision_kind) if decision_kind is not None else None + self.criteria = criteria + self.outcome = outcome + + def as_dict(self) -> Dict[str, Any]: + d = super().as_dict() + if self.decision_kind is not None: + d['decision_kind'] = self.decision_kind + if self.criteria: + d['criteria'] = self.criteria + if self.outcome is not None: + d['outcome'] = self.outcome + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'DecisionNode': + obj = object.__new__(cls) + obj.node_id = d['node_id'] + obj.node_type = d.get('node_type', NodeType.decision) + obj.label = d.get('label') + obj.timestamp = d.get('timestamp') + obj.metadata = d.get('metadata') + obj.decision_kind = d.get('decision_kind') + obj.criteria = d.get('criteria') + obj.outcome = d.get('outcome') + return obj + + +# ── Edge class ─────────────────────────────────────────────────────────────── + + +class ProvenanceEdge(object): + """ + A typed directed edge in the provenance DAG. + + Args: + source_id (str): Node ID of the edge source. + target_id (str): Node ID of the edge target. + edge_type (str): One of :class:`EdgeType` values. + metadata (dict, optional): Arbitrary extra key-value data. + """ + + def __init__(self, + source_id: str, + target_id: str, + edge_type: str, + metadata: Optional[Dict[str, Any]] = None, + ): + self.source_id = source_id + self.target_id = target_id + self.edge_type = _enum_val(edge_type) + self.metadata = metadata + + def as_dict(self) -> Dict[str, Any]: + d: Dict[str, Any] = { + 'source_id': self.source_id, + 'target_id': self.target_id, + 'edge_type': self.edge_type, + } + if self.metadata: + d['metadata'] = self.metadata + return d + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> 'ProvenanceEdge': + obj = object.__new__(cls) + obj.source_id = d['source_id'] + obj.target_id = d['target_id'] + obj.edge_type = d.get('edge_type', '') + obj.metadata = d.get('metadata') + return obj + + def __repr__(self) -> str: + return f'ProvenanceEdge({self.source_id!r} --{self.edge_type}--> {self.target_id!r})' diff --git a/arc/provenance/provenance_test.py b/arc/provenance/provenance_test.py new file mode 100644 index 0000000000..0db5aecef3 --- /dev/null +++ b/arc/provenance/provenance_test.py @@ -0,0 +1,626 @@ +"""Tests for the arc.provenance package — nodes, edges, and ProvenanceGraph.""" + +import os +import shutil +import tempfile +import unittest + +from arc.provenance.graph import SCHEMA_VERSION, ProvenanceGraph +from arc.provenance.nodes import ( + CalculationNode, + DataKind, + DataNode, + DecisionKind, + DecisionNode, + EdgeType, + NodeType, + ProvenanceEdge, + ProvenanceNode, +) + + +class TestEnums(unittest.TestCase): + """Verify that enums are str-based and contain expected values.""" + + def test_node_type_is_str(self): + self.assertIsInstance(NodeType.species, str) + self.assertEqual(NodeType.calculation, 'calculation') + + def test_data_kind_values(self): + self.assertIn('geometry', [dk.value for dk in DataKind]) + self.assertIn('energy', [dk.value for dk in DataKind]) + + def test_decision_kind_values(self): + expected = {'conformer_selection', 
'ts_guess_clustering', 'ts_guess_selection', + 'ts_guess_selection_failed', 'ts_validation_freq', 'ts_validation_nmd', + 'ts_validation_irc', 'ts_switch', 'job_troubleshooting', 'ts_method_spawning'} + actual = {dk.value for dk in DecisionKind} + self.assertEqual(expected, actual) + + def test_edge_type_values(self): + self.assertIn('input_of', [et.value for et in EdgeType]) + self.assertIn('selected_by', [et.value for et in EdgeType]) + self.assertIn('rejected_by', [et.value for et in EdgeType]) + + +class TestProvenanceNode(unittest.TestCase): + """Test the base ProvenanceNode class.""" + + def test_creation(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='ethanol') + self.assertEqual(node.node_id, 'species_1') + self.assertEqual(node.node_type, 'species') + self.assertEqual(node.label, 'ethanol') + self.assertIsNotNone(node.timestamp) + + def test_as_dict_sparse(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species) + d = node.as_dict() + self.assertIn('node_id', d) + self.assertIn('node_type', d) + self.assertNotIn('label', d) + self.assertNotIn('metadata', d) + + def test_as_dict_with_metadata(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, + label='H2O', metadata={'is_ts': True}) + d = node.as_dict() + self.assertEqual(d['metadata'], {'is_ts': True}) + + def test_from_dict_roundtrip(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, + label='ethanol', metadata={'is_ts': False}) + d = node.as_dict() + restored = ProvenanceNode.from_dict(d) + self.assertEqual(restored.node_id, 'species_1') + self.assertEqual(restored.node_type, 'species') + self.assertEqual(restored.label, 'ethanol') + + def test_from_dict_dispatches_to_subclass(self): + d = {'node_id': 'calc_1', 'node_type': 'calculation', 'job_name': 'opt_a1'} + restored = ProvenanceNode.from_dict(d) + self.assertIsInstance(restored, CalculationNode) + self.assertEqual(restored.job_name, 'opt_a1') + + def test_repr(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='ethanol') + self.assertIn('species_1', repr(node)) + + +class TestCalculationNode(unittest.TestCase): + """Test CalculationNode creation and serialization.""" + + def test_creation(self): + node = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1', + job_type='opt', job_adapter='gaussian', + level='wb97xd/def2-tzvp', status='done') + self.assertEqual(node.node_type, 'calculation') + self.assertEqual(node.job_name, 'opt_a1') + self.assertEqual(node.status, 'done') + + def test_as_dict_sparse(self): + node = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1') + d = node.as_dict() + self.assertIn('job_name', d) + self.assertNotIn('job_adapter', d) + self.assertNotIn('ess_trsh_methods', d) + + def test_from_dict_roundtrip(self): + node = CalculationNode(node_id='calc_1', label='ethanol', job_name='opt_a1', + job_type='opt', status='errored', + ess_trsh_methods=['SCF=QC', 'int=grid=ultrafine']) + d = node.as_dict() + restored = CalculationNode.from_dict(d) + self.assertEqual(restored.job_name, 'opt_a1') + self.assertEqual(restored.status, 'errored') + self.assertEqual(restored.ess_trsh_methods, ['SCF=QC', 'int=grid=ultrafine']) + self.assertIsNone(restored.conformer) + + +class TestDataNode(unittest.TestCase): + """Test DataNode creation and serialization.""" + + def test_creation(self): + node = DataNode(node_id='data_1', label='ethanol', + data_kind=DataKind.energy, value=-79.123456) 
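+        # DataKind is a str-based Enum, so _enum_val() stores the plain string 'energy'.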
+ self.assertEqual(node.node_type, 'data') + self.assertEqual(node.data_kind, 'energy') + self.assertEqual(node.value, -79.123456) + + def test_from_dict_roundtrip(self): + node = DataNode(node_id='data_1', label='ethanol', + data_kind=DataKind.frequencies, value=[3200.5, 1500.3, 800.1]) + d = node.as_dict() + restored = DataNode.from_dict(d) + self.assertEqual(restored.data_kind, 'frequencies') + self.assertEqual(restored.value, [3200.5, 1500.3, 800.1]) + + +class TestDecisionNode(unittest.TestCase): + """Test DecisionNode creation and serialization.""" + + def test_creation(self): + node = DecisionNode(node_id='decision_1', label='TS0', + decision_kind=DecisionKind.ts_guess_selection, + outcome='Selected TSGuess #3 (energy=-150.2 kJ/mol)') + self.assertEqual(node.node_type, 'decision') + self.assertEqual(node.decision_kind, 'ts_guess_selection') + self.assertIn('TSGuess #3', node.outcome) + + def test_from_dict_roundtrip(self): + node = DecisionNode(node_id='decision_1', label='TS0', + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': ['SCF', 'Memory'], + 'applied': 'SCF=QC'}, + outcome='Retrying with SCF=QC') + d = node.as_dict() + restored = DecisionNode.from_dict(d) + self.assertEqual(restored.decision_kind, 'job_troubleshooting') + self.assertEqual(restored.criteria['error_keywords'], ['SCF', 'Memory']) + self.assertEqual(restored.outcome, 'Retrying with SCF=QC') + + +class TestProvenanceEdge(unittest.TestCase): + """Test ProvenanceEdge creation and serialization.""" + + def test_creation(self): + edge = ProvenanceEdge(source_id='species_1', target_id='calc_1', + edge_type=EdgeType.input_of) + self.assertEqual(edge.source_id, 'species_1') + self.assertEqual(edge.edge_type, 'input_of') + + def test_as_dict_sparse(self): + edge = ProvenanceEdge(source_id='a', target_id='b', edge_type=EdgeType.output_of) + d = edge.as_dict() + self.assertNotIn('metadata', d) + + def test_from_dict_roundtrip(self): + edge = ProvenanceEdge(source_id='calc_1', target_id='data_1', + edge_type=EdgeType.output_of, + metadata={'reason': 'rerun'}) + d = edge.as_dict() + restored = ProvenanceEdge.from_dict(d) + self.assertEqual(restored.source_id, 'calc_1') + self.assertEqual(restored.metadata, {'reason': 'rerun'}) + + def test_repr(self): + edge = ProvenanceEdge(source_id='a', target_id='b', edge_type=EdgeType.selected_by) + self.assertIn('selected_by', repr(edge)) + + +class TestProvenanceGraph(unittest.TestCase): + """Test ProvenanceGraph CRUD, traversal, query, and serialization.""" + + def setUp(self): + self.graph = ProvenanceGraph(project='test_project') + + def test_add_species_node(self): + nid = self.graph.add_species_node(label='ethanol') + self.assertIn(nid, self.graph.nodes) + self.assertEqual(self.graph.nodes[nid].node_type, 'species') + self.assertEqual(self.graph.nodes[nid].label, 'ethanol') + + def test_add_calculation_node(self): + nid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='pending') + node = self.graph.get_node(nid) + self.assertIsInstance(node, CalculationNode) + self.assertEqual(node.job_name, 'opt_a1') + + def test_add_data_node(self): + nid = self.graph.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.5) + node = self.graph.get_node(nid) + self.assertIsInstance(node, DataNode) + self.assertEqual(node.value, -79.5) + + def test_add_decision_node(self): + nid = self.graph.add_decision_node(label='TS0', + decision_kind=DecisionKind.ts_guess_selection, + outcome='Selected TSG #2') + node = 
self.graph.get_node(nid) + self.assertIsInstance(node, DecisionNode) + self.assertEqual(node.outcome, 'Selected TSG #2') + + def test_node_id_auto_increment(self): + id1 = self.graph.add_species_node(label='A') + id2 = self.graph.add_species_node(label='B') + id3 = self.graph.add_calculation_node(label='A', job_name='opt_a1') + self.assertEqual(id1, 'species_1') + self.assertEqual(id2, 'species_2') + self.assertEqual(id3, 'calc_3') + + def test_duplicate_node_skipped(self): + node = ProvenanceNode(node_id='species_1', node_type=NodeType.species, label='X') + self.graph.add_node(node) + self.graph.add_node(node) + self.assertEqual(len(self.graph.nodes), 1) + + def test_add_edge(self): + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1') + edge = self.graph.add_edge(sid, cid, EdgeType.input_of) + self.assertEqual(len(self.graph.edges), 1) + self.assertEqual(edge.edge_type, 'input_of') + + def test_get_edges_from_and_to(self): + sid = self.graph.add_species_node(label='A') + c1 = self.graph.add_calculation_node(label='A', job_name='opt_a1') + c2 = self.graph.add_calculation_node(label='A', job_name='freq_a2') + self.graph.add_edge(sid, c1, EdgeType.input_of) + self.graph.add_edge(sid, c2, EdgeType.input_of) + self.assertEqual(len(self.graph.get_edges_from(sid)), 2) + self.assertEqual(len(self.graph.get_edges_to(c1)), 1) + + def test_get_nodes_by_type(self): + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + species_nodes = self.graph.get_nodes_by_type(NodeType.species) + self.assertEqual(len(species_nodes), 2) + calc_nodes = self.graph.get_nodes_by_type(NodeType.calculation) + self.assertEqual(len(calc_nodes), 1) + + def test_get_nodes_by_type_with_label_filter(self): + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + self.graph.add_calculation_node(label='B', job_name='opt') + a_calcs = self.graph.get_nodes_by_type(NodeType.calculation, label='A') + self.assertEqual(len(a_calcs), 1) + + def test_get_nodes_by_label(self): + self.graph.add_species_node(label='ethanol') + self.graph.add_calculation_node(label='ethanol', job_name='opt') + self.graph.add_calculation_node(label='methane', job_name='opt') + eth_nodes = self.graph.get_nodes_by_label('ethanol') + self.assertEqual(len(eth_nodes), 2) + + def test_find_species_node(self): + sid = self.graph.add_species_node(label='ethanol') + self.assertEqual(self.graph.find_species_node('ethanol'), sid) + self.assertIsNone(self.graph.find_species_node('missing')) + + def test_find_calc_node(self): + self.graph.add_calculation_node(label='A', job_name='opt_a1') + cid = self.graph.find_calc_node('A', 'opt_a1') + self.assertIsNotNone(cid) + self.assertIsNone(self.graph.find_calc_node('A', 'missing')) + + def test_update_node(self): + cid = self.graph.add_calculation_node(label='A', job_name='opt', status='pending') + self.assertTrue(self.graph.update_node(cid, status='done', run_time='00:05:30')) + node = self.graph.get_node(cid) + self.assertEqual(node.status, 'done') + self.assertEqual(node.run_time, '00:05:30') + + def test_update_node_missing(self): + self.assertFalse(self.graph.update_node('nonexistent', status='done')) + + def test_get_edges_by_type(self): + sid = self.graph.add_species_node(label='A') + c1 = self.graph.add_calculation_node(label='A', job_name='opt') + d1 = 
self.graph.add_data_node(label='A', data_kind=DataKind.energy) + self.graph.add_edge(sid, c1, EdgeType.input_of) + self.graph.add_edge(c1, d1, EdgeType.output_of) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.input_of)), 1) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.output_of)), 1) + self.assertEqual(len(self.graph.get_edges_by_type(EdgeType.selected_by)), 0) + + # ── Traversal ──────────────────────────────────────────────────────────── + + def test_descendants(self): + """species -> calc -> data -> decision""" + sid = self.graph.add_species_node(label='A') + cid = self.graph.add_calculation_node(label='A', job_name='opt') + did = self.graph.add_data_node(label='A', data_kind=DataKind.geometry) + dec = self.graph.add_decision_node(label='A', decision_kind=DecisionKind.conformer_selection) + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + self.graph.add_edge(did, dec, EdgeType.selected_by) + desc = self.graph.descendants(sid) + self.assertEqual(set(desc), {cid, did, dec}) + self.assertNotIn(sid, desc) + + def test_ancestors(self): + """Reverse traversal.""" + sid = self.graph.add_species_node(label='A') + cid = self.graph.add_calculation_node(label='A', job_name='opt') + did = self.graph.add_data_node(label='A', data_kind=DataKind.energy) + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + anc = self.graph.ancestors(did) + self.assertEqual(set(anc), {sid, cid}) + + def test_no_descendants(self): + sid = self.graph.add_species_node(label='A') + self.assertEqual(self.graph.descendants(sid), []) + + # ── Query ──────────────────────────────────────────────────────────────── + + def test_query_by_node_type(self): + self.graph.add_species_node(label='A') + self.graph.add_calculation_node(label='A', job_name='opt') + results = self.graph.query(node_type=NodeType.species) + self.assertEqual(len(results), 1) + + def test_query_by_decision_kind(self): + self.graph.add_decision_node(label='A', decision_kind=DecisionKind.ts_guess_selection) + self.graph.add_decision_node(label='A', decision_kind=DecisionKind.job_troubleshooting) + results = self.graph.query(decision_kind=DecisionKind.ts_guess_selection) + self.assertEqual(len(results), 1) + + def test_query_by_status(self): + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + self.graph.add_calculation_node(label='A', job_name='freq', status='errored') + done = self.graph.query(status='done') + self.assertEqual(len(done), 1) + self.assertEqual(done[0].job_name, 'opt') + + def test_query_combined_filters(self): + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + self.graph.add_calculation_node(label='B', job_name='opt', status='done') + results = self.graph.query(node_type=NodeType.calculation, label='A', status='done') + self.assertEqual(len(results), 1) + self.assertEqual(results[0].label, 'A') + + # ── Serialization ──────────────────────────────────────────────────────── + + def test_as_dict_structure(self): + self.graph.add_species_node(label='A') + d = self.graph.as_dict() + self.assertEqual(d['schema_version'], SCHEMA_VERSION) + self.assertEqual(d['project'], 'test_project') + self.assertIsInstance(d['nodes'], list) + self.assertIsInstance(d['edges'], list) + + def test_from_dict_roundtrip(self): + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + status='done') + 
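+        # A single edge suffices here: from_dict() must restore both the typed nodes and the edge.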
self.graph.add_edge(sid, cid, EdgeType.input_of) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + self.assertEqual(len(restored.nodes), 2) + self.assertEqual(len(restored.edges), 1) + self.assertEqual(restored.project, 'test_project') + self.assertIsInstance(restored.get_node(cid), CalculationNode) + self.assertEqual(restored.get_node(cid).status, 'done') + + def test_restart_continues_counter(self): + """After loading a graph, new node IDs should not collide with existing ones.""" + self.graph.add_species_node(label='A') + self.graph.add_species_node(label='B') + self.graph.add_calculation_node(label='A', job_name='opt') + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + new_id = restored.add_species_node(label='C') + # _counter should be at least 3 (from species_1, species_2, calc_3), + # so next ID should be species_4 or higher + self.assertNotIn(new_id, ['species_1', 'species_2', 'calc_3']) + + def test_save_and_load(self): + tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmp_dir) + path = os.path.join(tmp_dir, 'provenance_graph.yml') + sid = self.graph.add_species_node(label='ethanol') + cid = self.graph.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='done') + did = self.graph.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.5) + dec = self.graph.add_decision_node(label='ethanol', + decision_kind=DecisionKind.conformer_selection, + outcome='Selected conformer #0') + self.graph.add_edge(sid, cid, EdgeType.input_of) + self.graph.add_edge(cid, did, EdgeType.output_of) + self.graph.add_edge(did, dec, EdgeType.selected_by) + self.graph.save(path) + self.assertTrue(os.path.isfile(path)) + loaded = ProvenanceGraph.load(path) + self.assertEqual(len(loaded.nodes), 4) + self.assertEqual(len(loaded.edges), 3) + self.assertIsInstance(loaded.get_node(cid), CalculationNode) + self.assertIsInstance(loaded.get_node(did), DataNode) + self.assertIsInstance(loaded.get_node(dec), DecisionNode) + + def test_len_and_repr(self): + self.assertEqual(len(self.graph), 0) + self.graph.add_species_node(label='A') + self.assertEqual(len(self.graph), 1) + self.assertIn('test_project', repr(self.graph)) + + +class TestProvenanceGraphWorkflow(unittest.TestCase): + """ + Integration-style test: build a realistic provenance graph for a species + going through opt → freq → sp, with a troubleshooting retry on freq. 
+ """ + + def test_realistic_workflow(self): + g = ProvenanceGraph(project='workflow_test') + + # Species initialized + sid = g.add_species_node(label='ethanol') + + # Opt job succeeds + opt_id = g.add_calculation_node(label='ethanol', job_name='opt_a1', + job_type='opt', status='done') + g.add_edge(sid, opt_id, EdgeType.input_of) + opt_geo = g.add_data_node(label='ethanol', data_kind=DataKind.geometry, + source_path='calcs/opt_a1/output.log') + g.add_edge(opt_id, opt_geo, EdgeType.output_of) + + # Freq job fails + freq1_id = g.add_calculation_node(label='ethanol', job_name='freq_a2', + job_type='freq', status='errored') + g.add_edge(opt_geo, freq1_id, EdgeType.input_of) + + # Troubleshooting decision + trsh_id = g.add_decision_node(label='ethanol', + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': ['SCF']}, + outcome='Retrying with SCF=QC') + g.add_edge(freq1_id, trsh_id, EdgeType.troubleshot_by) + + # Freq job retried and succeeds + freq2_id = g.add_calculation_node(label='ethanol', job_name='freq_a3', + job_type='freq', status='done', + ess_trsh_methods=['SCF=QC']) + g.add_edge(trsh_id, freq2_id, EdgeType.spawned_by) + g.add_edge(freq1_id, freq2_id, EdgeType.retried_as) + freq_data = g.add_data_node(label='ethanol', data_kind=DataKind.frequencies, + value=[3200.5, 1500.3]) + g.add_edge(freq2_id, freq_data, EdgeType.output_of) + + # SP job succeeds + sp_id = g.add_calculation_node(label='ethanol', job_name='sp_a4', + job_type='sp', status='done') + g.add_edge(opt_geo, sp_id, EdgeType.input_of) + sp_energy = g.add_data_node(label='ethanol', data_kind=DataKind.energy, + value=-79.123456) + g.add_edge(sp_id, sp_energy, EdgeType.output_of) + + # Verify graph structure + self.assertEqual(len(g.nodes), 9) + self.assertEqual(len(g.edges), 9) + + # Verify traversal: ancestors of the final energy should trace back to species + anc = g.ancestors(sp_energy) + self.assertIn(sid, anc) + self.assertIn(opt_id, anc) + self.assertIn(sp_id, anc) + + # Verify query: find all troubleshooting decisions + trsh_decisions = g.query(decision_kind=DecisionKind.job_troubleshooting) + self.assertEqual(len(trsh_decisions), 1) + self.assertEqual(trsh_decisions[0].criteria['error_keywords'], ['SCF']) + + # Verify query: find all errored calculations + errored = g.query(node_type=NodeType.calculation, status='errored') + self.assertEqual(len(errored), 1) + self.assertEqual(errored[0].job_name, 'freq_a2') + + # Verify traversal: descendants of the troubleshooting decision + # should include the retried freq job and its output + desc = g.descendants(trsh_id) + self.assertIn(freq2_id, desc) + self.assertIn(freq_data, desc) + + +class TestEdgeCases(unittest.TestCase): + """Tests for edge cases identified during code review.""" + + def setUp(self): + self.graph = ProvenanceGraph(project='edge_case_test') + + def test_add_edge_warns_on_nonexistent_nodes(self): + """add_edge should still work but log warnings for missing nodes.""" + sid = self.graph.add_species_node(label='A') + edge = self.graph.add_edge(sid, 'nonexistent_target', EdgeType.input_of) + self.assertEqual(len(self.graph.edges), 1) + self.assertEqual(edge.target_id, 'nonexistent_target') + + def test_roundtrip_preserves_zero_value(self): + """DataNode with value=0 (falsy) must survive serialization.""" + nid = self.graph.add_data_node(label='A', data_kind=DataKind.energy, value=0) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + node = restored.get_node(nid) + self.assertIsInstance(node, DataNode) + 
self.assertEqual(node.value, 0) + + def test_roundtrip_preserves_false_in_metadata(self): + """Metadata with False values must survive serialization.""" + node = ProvenanceNode(node_id='species_99', node_type=NodeType.species, + label='X', metadata={'is_ts': False, 'converged': False}) + self.graph.add_node(node) + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + restored_node = restored.get_node('species_99') + self.assertEqual(restored_node.metadata['is_ts'], False) + self.assertEqual(restored_node.metadata['converged'], False) + + def test_roundtrip_omits_empty_ess_trsh_methods(self): + """CalculationNode with ess_trsh_methods=[] should omit it from dict.""" + node = CalculationNode(node_id='calc_99', label='A', ess_trsh_methods=[]) + d = node.as_dict() + self.assertNotIn('ess_trsh_methods', d) + + def test_ancestors_with_diamond_dependency(self): + """DAG diamond: A -> B -> D, A -> C -> D — ancestors(D) = {A, B, C}.""" + a = self.graph.add_species_node(label='A') + b = self.graph.add_calculation_node(label='A', job_name='opt') + c = self.graph.add_calculation_node(label='A', job_name='freq') + d = self.graph.add_data_node(label='A', data_kind=DataKind.energy) + self.graph.add_edge(a, b, EdgeType.input_of) + self.graph.add_edge(a, c, EdgeType.input_of) + self.graph.add_edge(b, d, EdgeType.output_of) + self.graph.add_edge(c, d, EdgeType.output_of) + anc = self.graph.ancestors(d) + self.assertEqual(set(anc), {a, b, c}) + + def test_descendants_handles_self_loop(self): + """If a self-loop is accidentally created, traversal should not infinite-loop.""" + nid = self.graph.add_species_node(label='A') + self.graph.add_edge(nid, nid, EdgeType.input_of) + desc = self.graph.descendants(nid) + self.assertIn(nid, desc) + + def test_query_enum_and_string_equivalence(self): + """Query with NodeType enum and plain string should return identical results.""" + self.graph.add_calculation_node(label='A', job_name='opt', status='done') + r1 = self.graph.query(node_type=NodeType.calculation) + r2 = self.graph.query(node_type='calculation') + self.assertEqual(len(r1), len(r2)) + self.assertEqual(r1[0].node_id, r2[0].node_id) + + def test_counter_with_mixed_prefixes_after_restart(self): + """Counter should track max across ALL prefixes, not per-prefix.""" + self.graph.add_species_node(label='A') # species_1 + self.graph.add_species_node(label='B') # species_2 + self.graph.add_calculation_node(label='A', job_name='opt') # calc_3 + d = self.graph.as_dict() + restored = ProvenanceGraph.from_dict(d) + # Counter should be >= 3, so next ID suffix is >= 4 + new_id = restored.add_data_node(label='A', data_kind=DataKind.energy) + suffix = int(new_id.split('_')[-1]) + self.assertGreaterEqual(suffix, 4) + + def test_render_all_edge_types(self): + """Verify render_provenance_graph handles every EdgeType without errors.""" + try: + import graphviz as gv_mod + except ImportError: + self.skipTest('graphviz not installed') + from arc.plotter import render_provenance_graph + g = ProvenanceGraph(project='edge_type_test') + n1 = g.add_species_node(label='A') + n2 = g.add_calculation_node(label='A', job_name='opt', status='done') + g.add_data_node(label='A', data_kind=DataKind.energy) + g.add_decision_node(label='A', decision_kind=DecisionKind.conformer_selection) + g.add_calculation_node(label='A', job_name='opt2', status='errored') + for et in list(EdgeType): + g.add_edge(n1, n2, et) + gv = render_provenance_graph(g, run_label='test') + dot = gv.source + self.assertIn('species_1', dot) + 
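+        # Node IDs share one global counter across node types, so the first calc node is calc_2.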
self.assertIn('calc_2', dot) + + def test_render_none_labels(self): + """Nodes with label=None should render using node_id as fallback.""" + try: + import graphviz as gv_mod + except ImportError: + self.skipTest('graphviz not installed') + from arc.plotter import render_provenance_graph + g = ProvenanceGraph(project='none_label_test') + g.add_species_node(label=None) + g.add_calculation_node(label=None, job_name='opt', status='pending') + gv = render_provenance_graph(g, run_label='test') + dot = gv.source + # Should not crash; node_id is used as fallback for species + self.assertIn('species_1', dot) + + +if __name__ == '__main__': + unittest.main() diff --git a/arc/scheduler.py b/arc/scheduler.py index 6f4d81f39b..df52dbb53b 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -9,9 +9,8 @@ import pprint import shutil import time - import numpy as np -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import Any, TYPE_CHECKING, List, Optional, Tuple, Union import arc.parser.parser as parser from arc import plotter @@ -58,6 +57,7 @@ ) from arc.species.perceive import perceive_molecule_from_xyz from arc.species.vectors import get_angle, calculate_dihedral_angle +from arc.provenance import (ProvenanceGraph, EdgeType, NodeType, DataKind, DecisionKind) if TYPE_CHECKING: from arc.job.adapter import JobAdapter @@ -298,12 +298,22 @@ def __init__(self, self.output_multi_spc = dict() self.report_e_elect = report_e_elect self.skip_nmd = skip_nmd + self.provenance = {'version': 1, + 'project': self.project, + 'run_id': f'{self.project}_{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}', + 'started_at': datetime.datetime.now().isoformat(timespec='seconds'), + 'events': list(), + } + self.provenance_path = os.path.join(self.project_directory, 'output', 'provenance.yml') + self.graph = ProvenanceGraph(project=self.project, run_id=self.provenance['run_id']) + self.graph_path = os.path.join(self.project_directory, 'output', 'provenance_graph.yml') self.species_dict, self.rxn_dict = dict(), dict() for species in self.species_list: self.species_dict[species.label] = species for rxn in self.rxn_list: self.rxn_dict[rxn.index] = rxn + self._initialize_provenance() if self.restart_dict is not None: self.output = self.restart_dict['output'] if 'output' in self.restart_dict else dict() self.output_multi_spc = self.restart_dict['output_multi_spc'] if 'output_multi_spc' in self.restart_dict else dict() @@ -326,6 +336,8 @@ def __init__(self, self.orbitals_level = orbitals_level self.unique_species_labels = list() self.save_restart = False + if self.restart_dict is not None: + self._sanitize_restart_output() if len(self.rxn_list): rxn_info_path = self.make_reaction_labels_info_file() @@ -369,6 +381,12 @@ def __init__(self, self.species_list.append(ts_species) self.species_dict[ts_species.label] = ts_species self.initialize_output_dict(ts_species.label) + self.record_provenance_event(event_type='species_initialized', + label=ts_species.label, + is_ts=True, + ) + if self.graph.find_species_node(ts_species.label) is None: + self.graph.add_species_node(label=ts_species.label, is_ts=True) else: # The TS species was already loaded from a restart dict or an Arkane YAML file. ts_species = None @@ -584,6 +602,81 @@ def _flush_pending_pipe_conf_sp(self) -> None: for i in sorted(conformer_indices - piped): self.run_sp_job(label=label, level=self.conformer_sp_level, conformer=i) + def _initialize_provenance(self): + """Load previous provenance when restarting and record the current run start. 
+ + On a fresh run (no restart_dict), the event log and graph start empty. + On a restart, the previous event log and graph are loaded and deduplicated. + """ + is_restart = self.restart_dict is not None + if is_restart and os.path.isfile(self.provenance_path): + try: + provenance = read_yaml_file(self.provenance_path) + except Exception: + logger.warning('Could not parse existing provenance.yml; starting a fresh provenance log.') + provenance = None + if isinstance(provenance, dict): + raw_events = provenance.get('events', list()) + if isinstance(raw_events, list) and all(isinstance(e, dict) for e in raw_events): + self.provenance['events'] = raw_events + else: + logger.warning('Existing provenance.yml has invalid events; starting with an empty event log.') + if is_restart and os.path.isfile(self.graph_path): + try: + self.graph = ProvenanceGraph.load(self.graph_path) + except Exception: + logger.warning('Could not parse existing provenance_graph.yml; starting a fresh graph.') + already_initialized = {e['label'] for e in self.provenance['events'] + if e.get('event_type') == 'species_initialized' and isinstance(e.get('label'), str)} + already_in_graph = {n.label for n in self.graph.get_nodes_by_type(NodeType.species)} + for species in self.species_list: + if species.label not in already_initialized: + self.record_provenance_event(event_type='species_initialized', + label=species.label, + is_ts=species.is_ts, + ) + if species.label not in already_in_graph: + self.graph.add_species_node(label=species.label, is_ts=species.is_ts) + + def record_provenance_event(self, + event_type: str, + label: Optional[str] = None, + **data: Any, + ): + """Append a provenance event and persist the event log.""" + max_id = max((e.get('event_id', 0) for e in self.provenance['events']), default=0) + event = {'event_id': max_id + 1, + 'event_type': event_type, + 'timestamp': datetime.datetime.now().isoformat(timespec='seconds'), + } + if label is not None: + event['label'] = label + for key, value in data.items(): + if value is not None and value != '' and value != list(): + event[key] = value + self.provenance['events'].append(event) + self.save_provenance() + + def save_provenance(self): + """Persist the provenance event log. The graph is saved lazily via save_provenance_graph().""" + output_directory = os.path.dirname(self.provenance_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + save_yaml_file(path=self.provenance_path, content=self.provenance) + + def save_provenance_graph(self): + """Persist the provenance graph to disk. 
Called at checkpoints and finalization, not per-event.""" + self.graph.save(self.graph_path) + + def finalize_provenance(self): + """Render final provenance artifacts after the run completes.""" + self.provenance['ended_at'] = datetime.datetime.now().isoformat(timespec='seconds') + self.graph.save(self.graph_path) + plotter.save_provenance_artifacts(project_directory=self.project_directory, + provenance=self.provenance, + graph=self.graph, + ) + def schedule_jobs(self): """ The main job scheduling block @@ -847,6 +940,7 @@ def schedule_jobs(self): # Generate a TS report: self.generate_final_ts_guess_report() + self.finalize_provenance() def run_job(self, job_type: str, @@ -873,6 +967,8 @@ def run_job(self, torsions: Optional[List[List[int]]] = None, times_rerun: int = 0, tsg: Optional[int] = None, + provenance_parent_job: Optional[str] = None, + provenance_reason: Optional[str] = None, xyz: Optional[Union[dict, List[dict]]]= None, ): """ @@ -901,6 +997,8 @@ def run_job(self, torsions (List[List[int]], optional): The 0-indexed atom indices of the torsion(s). trsh (str, optional): A troubleshooting keyword to be used in input files. tsg (int, optional): TSGuess number if optimizing TS guesses. + provenance_parent_job (str, optional): The job_name of the parent job that triggered this one. + provenance_reason (str, optional): Why this job was spawned (e.g., 'rerun', 'ess_troubleshoot', 'fine_opt'). xyz (Union[dict, List[dict]], optional): The 3D coordinates for the species. """ max_job_time = max_job_time or self.max_job_time # if it's None, set to default @@ -1004,6 +1102,44 @@ def run_job(self, if job.server is not None and job.server not in self.servers: self.servers.append(job.server) self.check_max_simultaneous_jobs_limit(job.server) + level_repr = None if job.level is None else str(job.level) + provenance_label = '+'.join(label) if isinstance(label, list) else label + self.record_provenance_event( + event_type='job_started', + label=provenance_label, + is_ts=self.species_dict[label].is_ts if isinstance(label, str) and label in self.species_dict else None, + job_key=f'{provenance_label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + job_adapter=job.job_adapter, + level=level_repr, + execution_type=job.execution_type, + ess_trsh_methods=job.ess_trsh_methods, + conformer=conformer, + tsg=tsg, + provenance_parent_job=provenance_parent_job, + provenance_reason=provenance_reason, + ) + # ── Graph: add CalculationNode ── + calc_node_id = self.graph.add_calculation_node( + label=provenance_label, + job_name=job.job_name, + job_type=job.job_type, + job_adapter=job.job_adapter, + level=level_repr, + status='pending', + conformer=conformer, + tsg=tsg, + ess_trsh_methods=job.ess_trsh_methods if job.ess_trsh_methods else None, + ) + species_node_id = self.graph.find_species_node(provenance_label) + if species_node_id is not None: + self.graph.add_edge(species_node_id, calc_node_id, EdgeType.belongs_to) + if provenance_parent_job: + parent_node_id = self.graph.find_calc_node(provenance_label, provenance_parent_job) + if parent_node_id is not None: + edge_type = EdgeType.fine_of if provenance_reason == 'fine_opt' else EdgeType.retried_as + self.graph.add_edge(parent_node_id, calc_node_id, edge_type) job.execute() self.save_restart_dict() @@ -1124,6 +1260,26 @@ def end_job(self, job: 'JobAdapter', self.timer = False job.write_completed_job_to_csv_file() logger.info(f' Ending job {job_name} for {label} (run time: {job.run_time})') + job_status_str = job.job_status[1]['status'] if 
job.job_status[1]['status'] else job.job_status[0] + self.record_provenance_event( + event_type='job_finished', + label=label, + is_ts=self.species_dict[label].is_ts if isinstance(label, str) and label in self.species_dict else None, + job_key=f'{label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + status=job_status_str, + keywords=job.job_status[1]['keywords'], + error=job.job_status[1]['error'], + run_time=str(job.run_time) if job.run_time is not None else None, + ) + # ── Graph: update CalculationNode status ── + prov_label = '+'.join(label) if isinstance(label, list) else label + calc_nid = self.graph.find_calc_node(prov_label, job.job_name) + if calc_nid is not None: + self.graph.update_node(calc_nid, + status=job_status_str, + run_time=str(job.run_time) if job.run_time is not None else None) if job.job_status[0] != 'done': return False if job.job_adapter in ['gaussian', 'terachem'] and os.path.isfile(os.path.join(job.local_path, 'check.chk')) \ @@ -1180,6 +1336,8 @@ def _run_a_job(self, torsions=job.torsions, times_rerun=job.times_rerun + int(rerun), tsg=job.tsg, + provenance_parent_job=job.job_name, + provenance_reason='rerun', xyz=job.xyz, ) @@ -1249,7 +1407,16 @@ def run_ts_conformer_jobs(self, label: str): Args: label (str): The TS species label. """ - self.species_dict[label].cluster_tsgs() + cluster_summary = self.species_dict[label].cluster_tsgs() + if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_clustering, + criteria={'n_before': cluster_summary['n_before'], + 'n_after': cluster_summary['n_after'], + 'merged': cluster_summary['merged']}, + outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', + ) plotter.save_conformers_file( project_directory=self.project_directory, label=label, @@ -1271,7 +1438,11 @@ def run_ts_conformer_jobs(self, label: str): if not piped_indices: self.job_dict[label]['conf_opt'] = dict() for i, tsg in enumerate(successful_tsgs): - tsg.conformer_index = i # Store the conformer index to match them later. 
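+            # Assign each guess a stable, persistent index once (instead of the enumeration
+            # position), so conformer bookkeeping stays consistent across clustering and restarts.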
+ if tsg.index is None: + existing_indices = [guess.index for guess in self.species_dict[label].ts_guesses + if guess.index is not None] + tsg.index = max(existing_indices or [-1]) + 1 + tsg.conformer_index = tsg.index if i in piped_indices: continue if 'conf_opt' not in self.job_dict[label]: @@ -1280,7 +1451,7 @@ def run_ts_conformer_jobs(self, label: str): xyz=tsg.initial_xyz, level_of_theory=self.ts_guess_level, job_type='conf_opt', - conformer=i, + conformer=tsg.index, ) elif len(successful_tsgs) == 1: if 'opt' not in self.job_dict[label].keys() and 'composite' not in self.job_dict[label].keys(): @@ -1425,6 +1596,16 @@ def run_sp_job(self, ) else: raise RuntimeError(f'Unable to set the path for the sp job for species {label}') + # ── Graph: emit energy DataNode from opt log (sp_level == opt_level) ── + if self.species_dict[label].e_elect is not None: + opt_calc_nid = self.graph.find_calc_node(label, recent_opt_job_name) \ + if recent_opt_job is not None else None + if opt_calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.energy, + value=round(self.species_dict[label].e_elect, 2), + metadata={'source': 'opt_log', 'note': 'SP energy parsed from opt output'}) + self.graph.add_edge(opt_calc_nid, data_nid, EdgeType.output_of) return if 'sp' not in self.job_dict[label].keys(): @@ -1708,6 +1889,7 @@ def spawn_ts_jobs(self): else: rxn.ts_species.tsg_spawned = True tsg_index = 0 + spawned_methods = [] for method in self.ts_adapters: if method in all_families_ts_adapters or \ (rxn.family is not None @@ -1719,7 +1901,21 @@ def spawn_ts_jobs(self): reactions=[rxn], tsg=tsg_index, ) + spawned_methods.append(method) tsg_index += 1 + # ── Graph: record TS method spawning decision ── + if spawned_methods: + dec_nid = self.graph.add_decision_node( + label=rxn.ts_label, + decision_kind=DecisionKind.ts_method_spawning, + criteria={'family': rxn.family, + 'all_adapters': list(self.ts_adapters), + 'spawned': spawned_methods}, + outcome=f'Spawned {len(spawned_methods)} TS guess methods', + ) + spc_nid = self.graph.find_species_node(rxn.ts_label) + if spc_nid is not None: + self.graph.add_edge(spc_nid, dec_nid, EdgeType.triggered_by) if all('user guess' in tsg.method for tsg in rxn.ts_species.ts_guesses): rxn.ts_species.tsg_spawned = True self.run_conformer_jobs(labels=[rxn.ts_label]) @@ -2083,9 +2279,14 @@ def parse_conformer(self, xyz = parser.parse_geometry(log_file_path=job.local_path_to_output_file) energy = parser.parse_e_elect(log_file_path=job.local_path_to_output_file) if self.species_dict[label].is_ts: - self.species_dict[label].ts_guesses[i].energy = energy - self.species_dict[label].ts_guesses[i].opt_xyz = xyz - self.species_dict[label].ts_guesses[i].index = i + tsg = next((guess for guess in self.species_dict[label].ts_guesses + if guess.conformer_index == i), None) + if tsg is None: + logger.warning(f'Could not find TSGuess for conformer {i} of {label} ' + f'(expected a matching conformer_index); skipping.') + return False + tsg.energy = energy + tsg.opt_xyz = xyz if energy is not None: logger.debug(f'Energy for TSGuess {i} of {label} is {energy:.2f}') else: @@ -2097,12 +2298,25 @@ def parse_conformer(self, logger.debug(f'Energy for conformer {i} of {label} is {energy:.2f}') else: logger.debug(f'Energy for conformer {i} of {label} is None') + # ── Graph: emit energy DataNode from conformer job ── + if energy is not None: + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + 
label=label, data_kind=DataKind.energy, value=round(energy, 2)) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) else: logger.warning(f'Conformer {i} for {label} did not converge.') if job.job_status[1]['status'] == 'errored' and job.times_rerun == 0: job.times_rerun += 1 - self.troubleshoot_ess(label=label, job=job, level_of_theory=job.level, conformer= job.conformer if job.conformer is not None else None) - return True + self.troubleshoot_ess(label=label, + job=job, + level_of_theory=job.level, + conformer=job.conformer if job.conformer is not None else None) + # Report "still troubleshooting" only if another job was actually queued. + # Conformer jobs are tracked in running_jobs as '{job_type}_{conformer}', not by job_name. + running_key = f'{job.job_type}_{job.conformer}' if job.conformer is not None else job.job_name + return label in self.running_jobs and running_key in self.running_jobs[label] if job.times_rerun == 0 and self.trsh_ess_jobs: self._run_a_job(job=job, label=label, rerun=True) return True @@ -2260,6 +2474,17 @@ def determine_most_stable_conformer(self, label, sp_flag=False): self.output[label]['job_types']['conf_opt'] = True if sp_flag: self.output[label]['job_types']['conf_sp'] = True + # ── Graph: record conformer selection decision ── + selected_idx = xyzs_in_original_order.index(conformer_xyz) + non_none_energies = [(i, e) for i, e in enumerate( + self.species_dict[label].conformer_energies) if e is not None] + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.conformer_selection, + criteria={'n_conformers': len(non_none_energies), + 'isomorphic': self.species_dict[label].conf_is_isomorphic}, + outcome=f'Selected conformer #{selected_idx}', + ) def determine_most_likely_ts_conformer(self, label: str): """ @@ -2269,7 +2494,16 @@ def determine_most_likely_ts_conformer(self, label: str): Args: label (str): The TS species label. 
""" - self.species_dict[label].cluster_tsgs() + cluster_summary = self.species_dict[label].cluster_tsgs() + if cluster_summary is not None and cluster_summary['n_before'] > cluster_summary['n_after']: + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_clustering, + criteria={'n_before': cluster_summary['n_before'], + 'n_after': cluster_summary['n_after'], + 'merged': cluster_summary['merged']}, + outcome=f'Clustered {cluster_summary["n_before"]} into {cluster_summary["n_after"]} unique guesses', + ) if not self.species_dict[label].is_ts: raise SchedulerError('determine_most_likely_ts_conformer() method only processes transition state guesses.') if not self.species_dict[label].successful_methods: @@ -2315,6 +2549,15 @@ def determine_most_likely_ts_conformer(self, label: str): logger.warning(f'Could not determine a likely TS conformer for {label}') self.species_dict[label].ts_number, self.species_dict[label].chosen_ts = None, None self.species_dict[label].populate_ts_checks() + self.record_provenance_event(event_type='ts_guess_selection_failed', + label=label, + is_ts=True, + ) + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_selection_failed, + outcome='No viable TS guess found', + ) return None else: rxn_txt = '' if self.species_dict[label].rxn_label is None \ @@ -2334,6 +2577,19 @@ def determine_most_likely_ts_conformer(self, label: str): self.species_dict[label].ts_guesses_exhausted = False if getattr(tsg, 'log_path', None): self.output[label]['paths']['neb'] = tsg.log_path + self.record_provenance_event(event_type='ts_guess_selected', + label=label, + is_ts=True, + selected_index=selected_i, + method=tsg.method, + energy=tsg.energy, + ) + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_guess_selection, + criteria={'selected_index': selected_i, 'energy': tsg.energy}, + outcome=f'Selected TSGuess #{selected_i} via {tsg.method}', + ) if tsg.success and tsg.energy is not None: # guess method and ts_level opt were both successful tsg.energy -= e_min im_freqs = f', imaginary frequencies {tsg.imaginary_freqs}' if tsg.imaginary_freqs is not None else '' @@ -2513,6 +2769,8 @@ def parse_opt_geo(self, level_of_theory=job.level, job_type='opt', fine=True, + provenance_parent_job=job.job_name, + provenance_reason='fine_opt', ) else: success = True @@ -2595,6 +2853,14 @@ def check_freq_job(self, freq_ok = self.check_negative_freq(label=label, job=job, vibfreqs=vibfreqs) if freq_ok and vibfreqs is not None: self.species_dict[label].freqs = [float(f) for f in vibfreqs] + # ── Graph: emit frequencies DataNode ── + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.frequencies, + value=len(vibfreqs), + metadata={'n_imaginary': sum(1 for f in vibfreqs if f < 0)}) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) if freq_ok: # Copy the frequency file to the species / TS output folder. folder_name = 'rxns' if self.species_dict[label].is_ts else 'Species' @@ -2621,6 +2887,12 @@ def check_freq_job(self, logger.info(f'TS {label} did not pass the normal mode displacement check. 
' f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') + # ── Graph: record NMD validation failure ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_nmd, + outcome='Failed: normal mode displacement check', + ) self.switch_ts(label) switch_ts = True if wrong_freq_message in self.output[label]['warnings']: @@ -2688,10 +2960,25 @@ def check_negative_freq(self, logger.info(f'TS {label} did not pass the negative frequency check. ' f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') + # ── Graph: record TS freq validation failure ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_freq, + criteria={'neg_freqs': [float(f) for f in neg_freqs], + 'expected': 1}, + outcome=f'Failed: {len(neg_freqs)} imaginary freqs, switching TS', + ) self.switch_ts(label=label) return False else: logger.info(f'TS {label} has exactly one imaginary frequency: {neg_freqs[0]}') + # ── Graph: record TS freq validation pass ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_validation_freq, + criteria={'neg_freqs': [float(neg_freqs[0])]}, + outcome='Passed: exactly 1 imaginary frequency', + ) self.output[label]['info'] += f'Imaginary frequency: {neg_freqs[0] if len(neg_freqs) == 1 else neg_freqs}; ' self.output[label]['job_types']['freq'] = True self.output[label]['paths']['freq'] = job.local_path_to_output_file @@ -2761,9 +3048,19 @@ def switch_ts(self, label: str): label (str): The TS species label. """ logger.info(f'Switching a TS guess for {label}...') + old_chosen = self.species_dict[label].chosen_ts self.determine_most_likely_ts_conformer(label=label) # Look for a different TS guess. + new_chosen = self.species_dict[label].chosen_ts + # ── Graph: record TS switch decision ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.ts_switch, + criteria={'old_chosen': old_chosen, 'new_chosen': new_chosen, + 'exhausted': self.species_dict[label].ts_guesses_exhausted}, + outcome=f'Switched from TSG #{old_chosen} to #{new_chosen}' + if new_chosen is not None else 'All TS guesses exhausted', + ) self.delete_all_species_jobs(label=label) # Delete other currently running jobs for this TS. 
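+        # delete_all_species_jobs() also resets convergence, job_types, and output paths for this label.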
- self.output[label]['geo'] = self.output[label]['freq'] = self.output[label]['sp'] = self.output[label]['composite'] = '' freq_path = os.path.join(self.project_directory, 'output', 'rxns', label, 'geometry', 'freq.out') if os.path.isfile(freq_path): os.remove(freq_path) @@ -2796,6 +3093,14 @@ def check_sp_job(self, sp_path=os.path.join(job.local_path_to_output_file), level=job.level, ) + # ── Graph: emit SP energy DataNode ── + if self.species_dict[label].e_elect is not None: + calc_nid = self.graph.find_calc_node(label, job.job_name) + if calc_nid is not None: + data_nid = self.graph.add_data_node( + label=label, data_kind=DataKind.energy, + value=round(self.species_dict[label].e_elect, 2)) + self.graph.add_edge(calc_nid, data_nid, EdgeType.output_of) # Update restart dictionary and save the yaml restart file: self.save_restart_dict() if self.species_dict[label].number_of_atoms == 1: @@ -2934,10 +3239,18 @@ def check_irc_species(self, label: str): if len(self.output[ts_label]['paths']['irc']) == 2: irc_species_labels = self.species_dict[ts_label].irc_label.split() if all(self.output[irc_label]['paths']['geo'] for irc_label in irc_species_labels): - check_irc_species_and_rxn(xyz_1=self.output[irc_species_labels[0]]['paths']['geo'], - xyz_2=self.output[irc_species_labels[1]]['paths']['geo'], - rxn=self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None), - ) + check_irc_species_and_rxn( + xyz_1=self.output[irc_species_labels[0]]['paths']['geo'], + xyz_2=self.output[irc_species_labels[1]]['paths']['geo'], + rxn=self.rxn_dict.get(self.species_dict[ts_label].rxn_index, None), + ) + # ── Graph: record IRC validation decision ── + self.graph.add_decision_node( + label=ts_label, + decision_kind=DecisionKind.ts_validation_irc, + criteria={'irc_species': irc_species_labels}, + outcome='IRC validation completed', + ) def check_scan_job(self, label: str, @@ -3195,6 +3508,9 @@ def check_all_done(self, label: str): logger.debug(f'Species {label} did not converge.') all_converged = False break + if all_converged and self._missing_required_paths(label): + logger.debug(f'Species {label} did not converge due to missing output paths.') + all_converged = False if label in self.output and all_converged: self.output[label]['convergence'] = True if self.species_dict[label].is_ts: @@ -3235,6 +3551,64 @@ def check_all_done(self, label: str): # Update restart dictionary and save the yaml restart file: self.save_restart_dict() + def _missing_required_paths(self, label: str) -> bool: + """ + Check whether required output paths are missing for a species/TS. + + Args: + label (str): The species label. + + Returns: + bool: Whether required output paths are missing. + """ + return bool(self._get_missing_required_paths(label)) + + def _get_missing_required_paths(self, label: str) -> set: + """ + Get missing required output path job types for a species/TS. + + Args: + label (str): The species label. + + Returns: + set: Job types with missing required output paths. 
+ """ + if label not in self.output or 'paths' not in self.output[label]: + return set() + path_map = { + 'opt': 'geo', + 'freq': 'freq', + 'sp': 'sp', + 'composite': 'composite', + } + missing = set() + for job_type, path_key in path_map.items(): + if job_type == 'composite': + required = self.composite_method is not None + else: + required = self.job_types.get(job_type, False) + if not required: + continue + if self.species_dict[label].number_of_atoms == 1 and job_type in ['opt', 'freq']: + continue + if self.output[label]['job_types'].get(job_type, False) and not self.output[label]['paths'].get(path_key, ''): + missing.add(job_type) + return missing + + def _sanitize_restart_output(self) -> None: + """ + Ensure restart output state is internally consistent (e.g., convergence without paths). + """ + for label in list(self.output.keys()): + if label not in self.species_dict: + continue + missing_job_types = self._get_missing_required_paths(label) + if self.output[label].get('convergence') and missing_job_types: + self.output[label]['convergence'] = False + if 'job_types' in self.output[label]: + for job_type in missing_job_types: + self.output[label]['job_types'][job_type] = False + def get_server_job_ids(self, specific_server: Optional[str] = None): """ Check job status on a specific server or on all active servers, get a list of relevant running job IDs. @@ -3308,6 +3682,16 @@ def troubleshoot_negative_freq(self, logger.info(f'Deleting all currently running jobs for species {label} before troubleshooting for ' f'negative frequency with perturbed conformers...') logger.info(f'conformers:') + # ── Graph: record negative freq troubleshooting ── + trsh_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'negative_freq', 'n_conformers': len(confs)}, + outcome=f'Generated {len(confs)} perturbed conformers', + ) + freq_calc_nid = self.graph.find_calc_node(label, job.job_name) + if freq_calc_nid is not None: + self.graph.add_edge(freq_calc_nid, trsh_nid, EdgeType.troubleshot_by) self.delete_all_species_jobs(label) self.species_dict[label].conformers = confs self.species_dict[label].conformer_energies = [None] * len(confs) @@ -3458,6 +3842,18 @@ def troubleshoot_scan_job(self, trsh={'scan_res': scan_res} if scan_res is not None else None, rotor_index=job.rotor_index, ) + # ── Graph: record scan troubleshooting decision ── + if trsh_success: + label = job.species_label + trsh_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'scan', 'actions': actual_actions}, + outcome=f'Scan troubleshooting: {", ".join(str(k) for k in actual_actions)}', + ) + scan_calc_nid = self.graph.find_calc_node(label, job.job_name) + if scan_calc_nid is not None: + self.graph.add_edge(scan_calc_nid, trsh_nid, EdgeType.troubleshot_by) return trsh_success, actual_actions def troubleshoot_opt_jobs(self, label): @@ -3555,7 +3951,14 @@ def troubleshoot_ess(self, f'log file:\n"{job.job_status[1]["line"]}".' 
logger.warning(warning_message) if self.species_dict[label].is_ts and conformer is not None: - xyz = self.species_dict[label].ts_guesses[conformer].get_xyz() + tsg = next((t for t in self.species_dict[label].ts_guesses + if t.index == conformer or t.conformer_index == conformer), None) + if tsg is not None: + xyz = tsg.get_xyz() + else: + logger.warning(f'Could not find TS guess with index {conformer} for {label}; ' + f'falling back to species xyz.') + xyz = self.species_dict[label].final_xyz or self.species_dict[label].initial_xyz elif conformer is not None: xyz = self.species_dict[label].conformers[conformer] else: @@ -3597,6 +4000,27 @@ def troubleshoot_ess(self, job.ess_trsh_methods = ess_trsh_methods if not couldnt_trsh: + self.record_provenance_event(event_type='job_troubleshooting', + label=label, + is_ts=self.species_dict[label].is_ts, + job_key=f'{label}:{job.job_name}', + job_name=job.job_name, + job_type=job.job_type, + methods=ess_trsh_methods, + keywords=job.job_status[1]['keywords'], + error=job.job_status[1]['error'], + ) + # ── Graph: record troubleshooting decision ── + trsh_dec_nid = self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'error_keywords': job.job_status[1]['keywords'], + 'error': job.job_status[1]['error']}, + outcome=f'Retrying with {", ".join(ess_trsh_methods[-1:])}' if ess_trsh_methods else 'Retrying', + ) + failed_calc_nid = self.graph.find_calc_node(label, job.job_name) + if failed_calc_nid is not None: + self.graph.add_edge(failed_calc_nid, trsh_dec_nid, EdgeType.troubleshot_by) self.run_job(label=label, xyz=xyz, level_of_theory=level_of_theory, @@ -3613,6 +4037,8 @@ def troubleshoot_ess(self, rotor_index=job.rotor_index, cpu_cores=cpu_cores, shift=shift, + provenance_parent_job=job.job_name, + provenance_reason='ess_troubleshoot', ) elif self.species_dict[label].is_ts and not self.species_dict[label].ts_guesses_exhausted \ and conformer is None: @@ -3622,6 +4048,17 @@ def troubleshoot_ess(self, f'Status is:\n{self.species_dict[label].ts_checks}\n' f'Searching for a better TS conformer...') self.switch_ts(label=label) + elif self.species_dict[label].is_ts and not self.species_dict[label].ts_guesses_exhausted: + # During TS conf_opt screening, avoid switching mid-batch since switch_ts() deletes all + # running jobs for this TS label and can discard other viable TS guesses still running. + if job.job_type == 'conf_opt': + logger.debug(f'Deferring TS switch for {label} during conf_opt batch screening.') + self.save_restart_dict() + return None + logger.info(f'TS {label} did not converge. ' + f'Status is:\n{self.species_dict[label].ts_checks}\n' + f'Searching for a better TS conformer...') + self.switch_ts(label=label) elif conformer is not None and couldnt_trsh: logger.warning(f'Could not troubleshoot conformer {conformer} for {label}. 
' f'Abandoning this conformer; waiting for others to finish.') @@ -3662,7 +4099,13 @@ def troubleshoot_conformer_isomorphism(self, label: str): 'graph representation!; ' else: logger.info(f'Troubleshooting conformer job in {job.job_adapter} using {level_of_theory} for species {label}') - + # ── Graph: record conformer isomorphism troubleshooting ── + self.graph.add_decision_node( + label=label, + decision_kind=DecisionKind.job_troubleshooting, + criteria={'type': 'conformer_isomorphism', 'new_level': str(level_of_theory)}, + outcome=f'Rerunning {num_of_conformers} conformers at {level_of_theory}', + ) # rerun conformer job at higher level for all conformers for conformer in range(0, num_of_conformers): if conformer >= len(self.species_dict[label].conformers_before_opt): @@ -3704,7 +4147,13 @@ def delete_all_species_jobs(self, label: str): logger.info(f'Deleted job {job_name}') job.delete() self.running_jobs[label] = list() - self.output[label]['paths'] = {key: '' if key != 'irc' else list() for key in self.output[label]['paths'].keys()} + if label in self.output: + self.output[label]['convergence'] = False + for key in ['opt', 'freq', 'sp', 'composite', 'fine']: + if key in self.output[label]['job_types']: + self.output[label]['job_types'][key] = False + self.output[label]['paths'] = {key: '' if key != 'irc' else list() + for key in self.output[label]['paths'].keys()} def restore_running_jobs(self): """ @@ -3806,8 +4255,9 @@ def save_restart_dict(self): for job_name in self.running_jobs[spc.label] if 'conf_sp' in job_name] \ + [self.job_dict[spc.label]['tsg'][get_i_from_job_name(job_name)].as_dict() for job_name in self.running_jobs[spc.label] if 'tsg' in job_name] - save_yaml_file(path=self.restart_path, content=self.restart_dict) - + save_yaml_file(path=self.restart_path, content=self.restart_dict) + self.save_provenance_graph() + def make_reaction_labels_info_file(self): """ A helper function for creating the `reactions labels.info` file. diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index 3216a9f254..44ac2442b0 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -8,6 +8,7 @@ import unittest import os import shutil +from unittest import mock import arc.parser.parser as parser from arc.checks.ts import check_ts @@ -19,7 +20,7 @@ from arc.imports import settings from arc.reaction import ARCReaction from arc.species.converter import str_to_xyz -from arc.species.species import ARCSpecies +from arc.species.species import ARCSpecies, TSGuess default_levels_of_theory = settings['default_levels_of_theory'] @@ -757,13 +758,223 @@ def test_add_label_to_unique_species_labels(self): self.assertEqual(unique_label, 'new_species_15_1') self.assertEqual(self.sched2.unique_species_labels, ['methylamine', 'C2H6', 'CtripCO', 'new_species_15', 'new_species_15_0', 'new_species_15_1']) + def test_initialize_provenance_dedup_on_restart(self): + """Test that _initialize_provenance does not re-emit species_initialized for species already in the log.""" + spc = ARCSpecies(label='ethanol', smiles='CCO') + project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov') + os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True) + # Write a fake provenance file that already has ethanol initialized. 
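+        # The structure mirrors what Scheduler.save_provenance() writes via save_yaml_file().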
+        from arc.common import save_yaml_file
+        save_yaml_file(path=os.path.join(project_directory, 'output', 'provenance.yml'),
+                       content={'version': 1, 'project': 'test', 'run_id': 'old_run',
+                                'started_at': '2026-01-01T00:00:00',
+                                'events': [{'event_id': 1, 'event_type': 'species_initialized',
+                                            'label': 'ethanol', 'is_ts': False}]})
+        # An (empty) restart_dict marks this run as a restart, so _initialize_provenance()
+        # actually loads the pre-written provenance.yml instead of starting a fresh log.
+        sched = Scheduler(project='test_prov_dedup', ess_settings=self.ess_settings,
+                          species_list=[spc],
+                          opt_level=Level(repr=default_levels_of_theory['opt']),
+                          freq_level=Level(repr=default_levels_of_theory['freq']),
+                          sp_level=Level(repr=default_levels_of_theory['sp']),
+                          project_directory=project_directory,
+                          testing=True, job_types=initialize_job_types(),
+                          restart_dict={'output': dict()})
+        init_events = [e for e in sched.provenance['events']
+                       if e['event_type'] == 'species_initialized' and e.get('label') == 'ethanol']
+        self.assertEqual(len(init_events), 1, 'species_initialized should not be duplicated on restart')
+        # New run should get its own run_id, not the old one.
+        self.assertNotEqual(sched.provenance['run_id'], 'old_run')
+        shutil.rmtree(project_directory, ignore_errors=True)
+
+    def test_sanitize_restart_output(self):
+        """Test that _sanitize_restart_output resets convergence when paths are missing."""
+        spc = ARCSpecies(label='H2O', smiles='O')
+        output = {
+            'H2O': {
+                'paths': {'geo': '', 'freq': '', 'sp': '', 'composite': ''},
+                'restart': '', 'convergence': True,
+                'job_types': {'conf_opt': False, 'conf_sp': False, 'opt': True, 'freq': True, 'sp': True,
+                              'rotors': False, 'irc': False, 'fine': False, 'composite': False},
+            }
+        }
+        sched = Scheduler(project='test_sanitize', ess_settings=self.ess_settings,
+                          species_list=[spc],
+                          opt_level=Level(repr=default_levels_of_theory['opt']),
+                          freq_level=Level(repr=default_levels_of_theory['freq']),
+                          sp_level=Level(repr=default_levels_of_theory['sp']),
+                          project_directory=self.project_directory,
+                          testing=True, job_types=initialize_job_types(),
+                          restart_dict={'output': output})
+        self.assertFalse(sched.output['H2O']['convergence'])
+        for key in ['opt', 'freq', 'sp']:
+            self.assertFalse(sched.output['H2O']['job_types'][key])
+
+    def test_delete_all_species_jobs_resets_output(self):
+        """Test that delete_all_species_jobs clears convergence, job_types, and paths."""
+        spc = ARCSpecies(label='CH4', smiles='C')
+        output = {
+            'CH4': {
+                'paths': {'geo': 'some/path.out', 'freq': 'freq.out', 'sp': 'sp.out', 'composite': ''},
+                'restart': '', 'convergence': True,
+                'job_types': {'conf_opt': False, 'conf_sp': False, 'opt': True, 'freq': True, 'sp': True,
+                              'rotors': False, 'irc': False, 'fine': True, 'composite': False},
+            }
+        }
+        sched = Scheduler(project='test_delete_jobs', ess_settings=self.ess_settings,
+                          species_list=[spc],
+                          opt_level=Level(repr=default_levels_of_theory['opt']),
+                          freq_level=Level(repr=default_levels_of_theory['freq']),
+                          sp_level=Level(repr=default_levels_of_theory['sp']),
+                          project_directory=self.project_directory,
+                          testing=True, job_types=initialize_job_types(),
+                          restart_dict={'output': output})
+        sched.running_jobs['CH4'] = []
+        sched.delete_all_species_jobs(label='CH4')
+        self.assertFalse(sched.output['CH4']['convergence'])
+        for key in ['opt', 'freq', 'sp', 'fine']:
+            self.assertFalse(sched.output['CH4']['job_types'][key])
+        self.assertEqual(sched.output['CH4']['paths']['geo'], '')
+
+    def test_conformer_index_set_before_run_job(self):
+        """Test that tsg.conformer_index is assigned before run_job is called, so restart state is consistent."""
+        ts_spc = ARCSpecies(label='TS0', is_ts=True, multiplicity=1, charge=0)
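+        # Indices are pre-set on the guesses, so run_ts_conformer_jobs() should reuse them as conformer_index.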
+ # Use geometries different enough to survive cluster_tsgs() deduplication. + ts_spc.ts_guesses = [ + TSGuess(method='autotst', index=0, success=True, + xyz={'symbols': ('C', 'H', 'H', 'H', 'H'), 'isotopes': (12, 1, 1, 1, 1), + 'coords': ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (-1, 0, 0))}, + project_directory=self.project_directory), + TSGuess(method='gcn', index=1, success=True, + xyz={'symbols': ('C', 'H', 'H', 'H', 'H'), 'isotopes': (12, 1, 1, 1, 1), + 'coords': ((0, 0, 0), (2, 0, 0), (0, 2, 0), (0, 0, 2), (-2, 0, 0))}, + project_directory=self.project_directory), + ] + sched = Scheduler(project='test_conf_index_order', ess_settings=self.ess_settings, + species_list=[ts_spc], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + ts_guess_level=Level(repr=default_levels_of_theory['ts_guesses']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types()) + # Track conformer_index values observed inside run_job. + observed = [] + + def capturing_run_job(**kwargs): + conformer = kwargs.get('conformer') + if conformer is not None: + tsg = next((g for g in ts_spc.ts_guesses if g.index == conformer), None) + observed.append((conformer, tsg.conformer_index if tsg else None)) + + with mock.patch.object(sched, 'run_job', side_effect=capturing_run_job), \ + mock.patch('arc.plotter.save_conformers_file'): + sched.run_ts_conformer_jobs(label='TS0') + + # Every call to run_job should have seen conformer_index already set. + self.assertTrue(len(observed) >= 2, f'Expected at least 2 conf_opt jobs, got {len(observed)}') + for conformer_idx, conformer_index_value in observed: + self.assertIsNotNone(conformer_index_value, + f'conformer_index was None when run_job was called for conformer {conformer_idx}') + self.assertEqual(conformer_idx, conformer_index_value) + + def test_provenance_records_ts_species_from_reactions(self): + """Test that TS species created from reactions get a species_initialized provenance event.""" + r_spc = ARCSpecies(label='nC3H7', smiles='[CH2]CC') + p_spc = ARCSpecies(label='iC3H7', smiles='C[CH]C') + rxn = ARCReaction(reactants=['nC3H7'], products=['iC3H7'], + r_species=[r_spc], p_species=[p_spc]) + rxn.index = 0 + sched = Scheduler(project='test_ts_prov', ess_settings=self.ess_settings, + species_list=[r_spc, p_spc], + rxn_list=[rxn], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + project_directory=self.project_directory, + testing=True, job_types=initialize_job_types()) + init_labels = [e['label'] for e in sched.provenance['events'] + if e.get('event_type') == 'species_initialized'] + self.assertIn('nC3H7', init_labels) + self.assertIn('iC3H7', init_labels) + self.assertIn('TS0', init_labels, 'TS species created from a reaction should get a species_initialized event') + + def test_provenance_multi_species_label(self): + """Test that provenance handles multi-species (list) labels by joining them.""" + spc1 = ARCSpecies(label='H2', smiles='[H][H]') + spc2 = ARCSpecies(label='O2', smiles='[O][O]') + sched = Scheduler(project='test_multi_label', ess_settings=self.ess_settings, + species_list=[spc1, spc2], + opt_level=Level(repr=default_levels_of_theory['opt']), + freq_level=Level(repr=default_levels_of_theory['freq']), + sp_level=Level(repr=default_levels_of_theory['sp']), + 
                          project_directory=self.project_directory,
+                          testing=True, job_types=initialize_job_types())
+        # Pass a list of labels; the scheduler is expected to join them into a single string.
+        sched.record_provenance_event(event_type='test_event', label=['H2', 'O2'])
+        event = sched.provenance['events'][-1]
+        self.assertEqual(event['label'], 'H2+O2')
+        self.assertIsInstance(event['label'], str)
+
+    def test_provenance_graph_species_initialized(self):
+        """Test that the provenance graph contains species nodes after initialization."""
+        spc = ARCSpecies(label='water', smiles='O')
+        project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov_graph')
+        os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True)
+        sched = Scheduler(project='test_prov_graph', ess_settings=self.ess_settings,
+                          species_list=[spc],
+                          opt_level=Level(repr=default_levels_of_theory['opt']),
+                          freq_level=Level(repr=default_levels_of_theory['freq']),
+                          sp_level=Level(repr=default_levels_of_theory['sp']),
+                          project_directory=project_directory,
+                          testing=True, job_types=initialize_job_types())
+        from arc.provenance import NodeType
+        species_nodes = sched.graph.get_nodes_by_type(NodeType.species)
+        self.assertEqual(len(species_nodes), 1)
+        self.assertEqual(species_nodes[0].label, 'water')
+        # The graph is saved lazily (at checkpoints/finalization, not per event).
+        # Verify that it can be saved on demand.
+        sched.save_provenance_graph()
+        self.assertTrue(os.path.isfile(sched.graph_path))
+        shutil.rmtree(project_directory, ignore_errors=True)
+
+    def test_provenance_graph_restart_preserves_nodes(self):
+        """Test that the provenance graph is restored correctly on restart."""
+        spc = ARCSpecies(label='methane', smiles='C')
+        project_directory = os.path.join(ARC_PATH, 'Projects', 'arc_project_for_testing_delete_after_usage_prov_graph2')
+        os.makedirs(os.path.join(project_directory, 'output'), exist_ok=True)
+        # Create the initial scheduler to write the provenance files.
+        sched1 = Scheduler(project='test_restart', ess_settings=self.ess_settings,
+                           species_list=[spc],
+                           opt_level=Level(repr=default_levels_of_theory['opt']),
+                           freq_level=Level(repr=default_levels_of_theory['freq']),
+                           sp_level=Level(repr=default_levels_of_theory['sp']),
+                           project_directory=project_directory,
+                           testing=True, job_types=initialize_job_types())
+        n_nodes_before = len(sched1.graph)
+        self.assertGreater(n_nodes_before, 0)
+        sched1.save_provenance_graph()  # Persist the graph so the restart can load it.
+        # Create a second scheduler on the same directory (simulates a restart).
+        sched2 = Scheduler(project='test_restart', ess_settings=self.ess_settings,
+                           species_list=[spc],
+                           opt_level=Level(repr=default_levels_of_theory['opt']),
+                           freq_level=Level(repr=default_levels_of_theory['freq']),
+                           sp_level=Level(repr=default_levels_of_theory['sp']),
+                           project_directory=project_directory,
+                           testing=True, job_types=initialize_job_types())
+        from arc.provenance import NodeType
+        species_nodes = sched2.graph.get_nodes_by_type(NodeType.species)
+        # There should still be exactly one species node (no duplicates).
+        self.assertEqual(len(species_nodes), 1)
+        self.assertEqual(species_nodes[0].label, 'methane')
+        shutil.rmtree(project_directory, ignore_errors=True)
+
     @classmethod
     def tearDownClass(cls):
         """
         A function that is run ONCE after all unit tests in this class.
Delete all project directories created during these unit tests """ - projects = ['arc_project_for_testing_delete_after_usage3', 'arc_project_for_testing_delete_after_usage6'] + projects = ['arc_project_for_testing_delete_after_usage3', 'arc_project_for_testing_delete_after_usage6', + 'arc_project_for_testing_delete_after_usage_prov', + 'arc_project_for_testing_delete_after_usage_prov_graph', + 'arc_project_for_testing_delete_after_usage_prov_graph2'] for project in projects: project_directory = os.path.join(ARC_PATH, 'Projects', project) shutil.rmtree(project_directory, ignore_errors=True) diff --git a/arc/species/species.py b/arc/species/species.py index 76a9105dfd..485651eb4c 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -1540,12 +1540,12 @@ def make_ts_report(self): self.ts_report += ':\n' if self.successful_methods: self.ts_report += 'Methods that successfully generated a TS guess:\n' - for successful_method in self.successful_methods: - self.ts_report += successful_method + ',' + unique_successful_methods = list(dict.fromkeys(self.successful_methods)) + self.ts_report += ','.join(unique_successful_methods) if self.unsuccessful_methods: - self.ts_report += '\nMethods that were unsuccessfully in generating a TS guess:\n' - for unsuccessful_method in self.unsuccessful_methods: - self.ts_report += unsuccessful_method + ',' + self.ts_report += '\nMethods that were unsuccessful in generating a TS guess:\n' + unique_unsuccessful_methods = list(dict.fromkeys(self.unsuccessful_methods)) + self.ts_report += ','.join(unique_unsuccessful_methods) if not self.ts_guesses_exhausted: self.ts_report += f'\nThe method that generated the best TS guess and its output used for the ' \ f'optimization: {self.chosen_ts_method}\n' @@ -1553,6 +1553,11 @@ def make_ts_report(self): def cluster_tsgs(self): """ Cluster TSGuesses. + + Returns: + Optional[dict]: ``None`` if this species is not a TS or has no TS guesses. + Otherwise a summary dict with keys ``n_before``, ``n_after``, and + ``merged`` (list of lists of merged indices). """ if not self.is_ts or not len(self.ts_guesses): return None @@ -1574,6 +1579,9 @@ def cluster_tsgs(self): if len(cluster_tsgs) < n_before: logger.info(f'Clustered {n_before} TS guesses for {self.label} ' f'into {len(cluster_tsgs)} unique conformers.') + return {'n_before': n_before, + 'n_after': len(cluster_tsgs), + 'merged': [tsg.cluster for tsg in cluster_tsgs if len(tsg.cluster) > 1]} def process_completed_tsg_queue_jobs(self, path: str): """ diff --git a/arc/species/species_test.py b/arc/species/species_test.py index 466217ff36..6bc884a2e8 100644 --- a/arc/species/species_test.py +++ b/arc/species/species_test.py @@ -1201,7 +1201,7 @@ def test_from_dict(self): 'ts_guesses_exhausted': False, 'ts_number': 0, 'ts_report': 'TS method summary for TS0 in C3_1 <=> C3_2:\n' 'Methods that successfully generated a TS guess:\n' - 'autotst,autotst,autotst,autotst,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,gcn,kinbot,kinbot,\n' + 'autotst,gcn,kinbot\n' 'The method that generated the best TS guess and its output used ' 'for the optimization: gcn\n', 'tsg_spawned': True, 'unsuccessful_methods': []} diff --git a/environment.yml b/environment.yml index e885465c0f..fef93f1630 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - conda-forge::ffmpeg - conda-forge::gprof2dot - conda-forge::graphviz + - conda-forge::python-graphviz - conda-forge::h5py - conda-forge::ipython - conda-forge::jupyter
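A note on the cluster_tsgs() summary introduced in the arc/species/species.py hunk above:
the returned dict lends itself to a quick sanity check at call sites. A minimal usage
sketch, not part of this patch, assuming a TS ARCSpecies whose ts_guesses are already
populated (the ts_species variable name is hypothetical):

    summary = ts_species.cluster_tsgs()
    if summary is not None:  # None for non-TS species or when there are no TS guesses.
        # 'merged' holds one index list per surviving conformer that absorbed other guesses.
        print(f"Clustered {summary['n_before']} TS guesses into {summary['n_after']} unique conformers; "
              f"merged groups: {summary['merged']}")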