From a0b236890fa3af29f8c9df5279dccb3c35d877f1 Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Thu, 2 Apr 2026 00:15:45 +0300 Subject: [PATCH 1/4] Added the ND Scan module Contains pure scan-surface helpers --- arc/species/nd_scan.py | 2040 +++++++++++++++++++++++++++++++++++ arc/species/nd_scan_test.py | 1469 +++++++++++++++++++++++++ 2 files changed, 3509 insertions(+) create mode 100644 arc/species/nd_scan.py create mode 100644 arc/species/nd_scan_test.py diff --git a/arc/species/nd_scan.py b/arc/species/nd_scan.py new file mode 100644 index 0000000000..a09fb21128 --- /dev/null +++ b/arc/species/nd_scan.py @@ -0,0 +1,2040 @@ +""" +A module for N-dimensional (ND) rotor scan utilities. + +Contains pure scan-surface helpers: +grid generation, point iteration, continuous-scan state management, +energy normalization, and adaptive sparse 2D scan logic. +Does **not** own job orchestration or species metadata -- +those remain in Scheduler. +""" + +import itertools +import math +from typing import Dict, Iterator, List, Optional, Tuple, Union + +from arc.common import extremum_list, get_angle_in_180_range, get_logger +from arc.exceptions import InputError, SchedulerError +from arc.species.vectors import calculate_dihedral_angle + +logger = get_logger() + + +ADAPTIVE_DEFAULT_BATCH_SIZE = 10 +ADAPTIVE_DEFAULT_MAX_POINTS = 200 +ADAPTIVE_DEFAULT_MIN_POINTS = 20 + +VALIDATION_ENERGY_JUMP_THRESHOLD = 30.0 # kJ/mol between adjacent grid points +VALIDATION_GEOMETRY_RMSD_THRESHOLD = 1.5 # Angstrom, distance-matrix RMSD +VALIDATION_PERIODIC_ENERGY_THRESHOLD = 5.0 # kJ/mol mismatch across wraparound +VALIDATION_PERIODIC_RMSD_THRESHOLD = 1.0 # Angstrom across wraparound +VALIDATION_BRANCH_JUMP_EDGE_COUNT = 2 # min suspicious edges to flag a point + + +def validate_scan_resolution(increment: float) -> None: + """ + Validate that the scan resolution divides 360 evenly. + + Args: + increment (float): The scan resolution in degrees. 
+ + Raises: + SchedulerError: If the increment is not positive or does not divide 360 evenly. + """ + if increment <= 0: + raise SchedulerError(f'The directed scan got a non-positive scan resolution of {increment}') + quotient = 360.0 / increment + if not math.isclose(quotient, round(quotient), abs_tol=1e-9): + raise SchedulerError(f'The directed scan got an illegal scan resolution of {increment}') + + +def get_torsion_dihedral_grid(xyz: dict, + torsions: list, + increment: float, + ) -> Dict[Tuple[int, ...], List[float]]: + """ + Build the per-torsion list of dihedral angles for a brute-force scan. + + For each torsion in ``torsions``, computes the current dihedral from ``xyz`` + and generates a list of ``int(360 / increment) + 1`` evenly spaced angles + starting from that dihedral, each wrapped into the -180..+180 range. + + Args: + xyz (dict): The 3D coordinates (ARC xyz dict with 'coords' key). + torsions (list): List of torsion definitions (each a list of 4 atom indices, 0-indexed). + increment (float): The scan resolution in degrees. + + Returns: + dict: Keys are torsion tuples, values are lists of dihedral angles. + """ + dihedrals = dict() + for torsion in torsions: + original_dihedral = get_angle_in_180_range( + calculate_dihedral_angle(coords=xyz['coords'], torsion=torsion, index=0)) + dihedrals[tuple(torsion)] = [ + get_angle_in_180_range(original_dihedral + i * increment) + for i in range(int(360 / increment) + 1) + ] + return dihedrals + + +def iter_brute_force_scan_points(dihedrals_by_torsion: Dict[Tuple[int, ...], List[float]], + torsions: list, + diagonal: bool = False, + ) -> Iterator[Tuple[float, ...]]: + """ + Yield dihedral-angle tuples for every point in a brute-force scan. + + Args: + dihedrals_by_torsion (dict): Mapping ``{torsion_tuple: [angle_0, angle_1, ...]}`` + as returned by :func:`get_torsion_dihedral_grid`. + torsions (list): Ordered list of torsion definitions (each a list of 4 ints). 
def initialize_continuous_scan_state(rotor_dict: dict,
                                     xyz: dict,
                                     ) -> None:
    """
    Initialize the continuous-scan bookkeeping fields on a rotor dict
    (``cont_indices`` and ``original_dihedrals``) if they have not been set yet.

    Modifies ``rotor_dict`` **in place**.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        xyz (dict): The 3D coordinates (ARC xyz dict).
    """
    torsions = rotor_dict['torsion']
    if not len(rotor_dict['cont_indices']):
        rotor_dict['cont_indices'] = [0] * len(torsions)
    if not len(rotor_dict['original_dihedrals']):
        # Stored as formatted strings for YAML round-trip compatibility.
        # Note: 'scan' entries are 1-indexed, hence index=1 here.
        rotor_dict['original_dihedrals'] = [
            f'{calculate_dihedral_angle(coords=xyz["coords"], torsion=scan, index=1):.2f}'
            for scan in rotor_dict['scan']
        ]


def get_continuous_scan_dihedrals(rotor_dict: dict,
                                  increment: float,
                                  ) -> List[float]:
    """
    Compute the dihedral angles for the *current* continuous-scan step,
    based on ``cont_indices`` and ``original_dihedrals`` stored in the rotor dict.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.

    Returns:
        list: A list of dihedral angles (one per torsion) for this step.
    """
    dihedrals = []
    for index, original_dihedral_str in enumerate(rotor_dict['original_dihedrals']):
        original_dihedral = get_angle_in_180_range(float(original_dihedral_str))
        dihedral = original_dihedral + rotor_dict['cont_indices'][index] * increment
        dihedrals.append(get_angle_in_180_range(dihedral))
    return dihedrals


def is_continuous_scan_complete(rotor_dict: dict,
                                increment: float,
                                ) -> bool:
    """
    Check whether a continuous directed scan has visited every grid point.

    The scan is complete when the outermost (last) odometer counter has
    reached the final 0-indexed grid point.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.

    Returns:
        bool: ``True`` if the scan is complete (all counters exhausted).
    """
    # round() rather than the original float arithmetic: for an increment such
    # as 7.2 deg, 360/7.2 == 49.999... and the integer counter would never
    # compare equal to it, so the scan would never be reported complete.
    last_index = round(360 / increment)  # final 0-indexed point per dimension
    return rotor_dict['cont_indices'][-1] == last_index


def increment_continuous_scan_indices(rotor_dict: dict,
                                      increment: float,
                                      diagonal: bool = False,
                                      ) -> None:
    """
    Advance the continuous-scan counters by one step.

    For a diagonal scan every counter is incremented together.
    For a non-diagonal scan the counters are incremented like an odometer
    (innermost dimension first).

    Modifies ``rotor_dict['cont_indices']`` **in place**.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.
        diagonal (bool, optional): Whether this is a diagonal scan.
    """
    torsions = rotor_dict['torsion']
    # Integer step count avoids float-equality failures for inexact increments
    # (e.g. 360/7.2 == 49.999..., which never equals an int counter).
    last_index = round(360 / increment)  # final 0-indexed point per dimension

    if diagonal:
        rotor_dict['cont_indices'] = [rotor_dict['cont_indices'][0] + 1] * len(torsions)
    else:
        for index in range(len(torsions)):
            if rotor_dict['cont_indices'][index] < last_index:
                rotor_dict['cont_indices'][index] += 1
                break
            elif rotor_dict['cont_indices'][index] == last_index and index < len(torsions) - 1:
                # Carry: reset this dimension and advance the next one.
                rotor_dict['cont_indices'][index] = 0
+ """ + torsions = rotor_dict['torsion'] + max_num = 360 / increment + 1 + + if diagonal: + rotor_dict['cont_indices'] = [rotor_dict['cont_indices'][0] + 1] * len(torsions) + else: + for index in range(len(torsions)): + if rotor_dict['cont_indices'][index] < max_num - 1: + rotor_dict['cont_indices'][index] += 1 + break + elif rotor_dict['cont_indices'][index] == max_num - 1 and index < len(torsions) - 1: + rotor_dict['cont_indices'][index] = 0 + + +def normalize_directed_scan_energies(rotor_dict: dict) -> Tuple[dict, int]: + """ + Build a ``results`` dict for a non-ESS directed scan and normalize energies + so that the minimum is zero. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + Must contain ``'directed_scan'``, ``'directed_scan_type'``, and ``'scan'`` keys. + + Returns: + tuple: A two-element tuple: + - results (dict): ``{'directed_scan_type': ..., 'scans': ..., 'directed_scan': ...}`` + with energies shifted so the minimum is 0. + - trshed_points (int): Number of scan points that required troubleshooting. + """ + dihedrals = [[float(d) for d in key] for key in rotor_dict['directed_scan'].keys()] + sorted_dihedrals = sorted(dihedrals) + min_energy = extremum_list( + [entry['energy'] for entry in rotor_dict['directed_scan'].values()], + return_min=True, + ) + results = { + 'directed_scan_type': rotor_dict['directed_scan_type'], + 'scans': rotor_dict['scan'], + 'directed_scan': rotor_dict['directed_scan'], + } + trshed_points = 0 + for dihedral_list in sorted_dihedrals: + key = tuple(f'{d:.2f}' for d in dihedral_list) + dihedral_dict = results['directed_scan'][key] + if dihedral_dict['trsh']: + trshed_points += 1 + if dihedral_dict['energy'] is not None and min_energy is not None: + dihedral_dict['energy'] -= min_energy + return results, trshed_points + + +def format_dihedral_key(dihedrals: list) -> Tuple[str, ...]: + """ + Build the legacy string-tuple key used to index ``rotor_dict['directed_scan']``. 
+ + Args: + dihedrals (list): A list of dihedral angles (floats). + + Returns: + tuple: A tuple of ``'{angle:.2f}'`` strings, one per dihedral. + """ + return tuple(f'{dihedral:.2f}' for dihedral in dihedrals) + + +def record_directed_scan_point(rotor_dict: dict, + dihedrals: list, + energy: Optional[float], + xyz: Optional[dict], + is_isomorphic: bool, + trsh: list, + ) -> None: + """ + Record a single completed directed-scan point into the legacy + ``rotor_dict['directed_scan']`` structure. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + dihedrals (list): The dihedral angles that define this scan point. + energy (float or None): The electronic energy (absolute, un-normalized). + xyz (dict or None): The optimized geometry for this point. + is_isomorphic (bool): Whether the optimized geometry is isomorphic with the species graph. + trsh (list): Troubleshooting methods applied to this point. + """ + key = format_dihedral_key(dihedrals) + rotor_dict['directed_scan'][key] = { + 'energy': energy, + 'xyz': xyz, + 'is_isomorphic': is_isomorphic, + 'trsh': trsh, + } + + +def get_rotor_dict_by_pivots(rotors_dict: dict, + pivots: Union[List[int], List[List[int]]], + ) -> Optional[Tuple[int, dict]]: + """ + Look up a rotor dict entry by its pivots. + + Args: + rotors_dict (dict): The full ``species.rotors_dict`` mapping. + pivots: The pivot(s) to match against. + + Returns: + tuple or None: ``(rotor_index, rotor_dict)`` if found, else ``None``. + """ + for rotor_index, rotor_dict in rotors_dict.items(): + if rotor_dict['pivots'] == pivots: + return rotor_index, rotor_dict + return None + + +def finalize_directed_scan_results(rotor_dict: dict, + parse_nd_scan_energies_func=None, + increment: Optional[float] = None, + ) -> Tuple[dict, int]: + """ + Produce the final results payload for a completed directed scan. + + For ESS-controlled scans (``directed_scan_type == 'ess'``), delegates to the + parser via ``parse_nd_scan_energies_func``. 
def decrement_running_jobs(rotor_dict: dict) -> bool:
    """
    Decrement the brute-force running-jobs counter and return whether all jobs
    for this rotor have finished.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.

    Returns:
        bool: ``True`` if all brute-force jobs for this rotor have terminated
            (counter reached 0).
    """
    remaining = rotor_dict['number_of_running_jobs'] - 1
    if remaining < 0:
        # Counter underflow indicates a bookkeeping mismatch upstream; clamp
        # so a spurious extra decrement cannot wedge the completion check.
        logger.warning(f'Running jobs counter went below zero '
                       f'({remaining}), clamping to 0.')
        remaining = 0
    rotor_dict['number_of_running_jobs'] = remaining
    return remaining == 0
+ """ + rotor_dict['number_of_running_jobs'] -= 1 + if rotor_dict['number_of_running_jobs'] < 0: + logger.warning(f'Running jobs counter went below zero ' + f'({rotor_dict["number_of_running_jobs"]}), clamping to 0.') + rotor_dict['number_of_running_jobs'] = 0 + return rotor_dict['number_of_running_jobs'] == 0 + + +# =========================================================================== +# Adaptive sparse 2D scan helpers +# =========================================================================== + +def _angular_distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float: + """ + Compute the Euclidean distance between two 2D angle points + using periodic-aware differences on each dimension. + + Args: + p1 (tuple): First point ``(phi0, phi1)`` in degrees. + p2 (tuple): Second point ``(phi0, phi1)`` in degrees. + + Returns: + float: The distance in degrees. + """ + d0 = abs(p1[0] - p2[0]) % 360.0 + d0 = min(d0, 360.0 - d0) + d1 = abs(p1[1] - p2[1]) % 360.0 + d1 = min(d1, 360.0 - d1) + return math.hypot(d0, d1) + + +def _normalize_angle_key(phi: float) -> float: + """Wrap an angle into -180..+180 and round to 2 decimals.""" + return round(get_angle_in_180_range(phi), 2) + + +def point_to_key(point: Tuple[float, float]) -> Tuple[str, ...]: + """Convert a 2D angle tuple to the normalized legacy string key.""" + return tuple(f'{_normalize_angle_key(a):.2f}' for a in point) + + +# --------------------------------------------------------------------------- +# Policy / eligibility +# --------------------------------------------------------------------------- + +def is_adaptive_eligible(rotor_dict: dict) -> bool: + """ + Check whether a rotor dict is eligible for adaptive sparse scanning. + + Eligibility requires: + * ``directed_scan_type`` is ``'brute_force_sp'`` or ``'brute_force_opt'`` + * ``dimensions`` == 2 + * not a diagonal scan type + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. 
+ + Returns: + bool: ``True`` if the rotor is eligible for adaptive scanning. + """ + dst = rotor_dict.get('directed_scan_type', '') + if dst not in ('brute_force_sp', 'brute_force_opt'): + return False + if rotor_dict.get('dimensions', 1) != 2: + return False + if 'diagonal' in dst: + return False + return True + + +def is_adaptive_enabled(rotor_dict: dict) -> bool: + """ + Check whether adaptive scanning is both eligible and enabled for a rotor. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + + Returns: + bool: ``True`` if the rotor should use adaptive sparse scanning. + """ + if not is_adaptive_eligible(rotor_dict): + return False + policy = rotor_dict.get('sampling_policy', 'dense') + return policy == 'adaptive' + + +# --------------------------------------------------------------------------- +# Adaptive state initialization +# --------------------------------------------------------------------------- + +def _make_empty_adaptive_state(batch_size: int = ADAPTIVE_DEFAULT_BATCH_SIZE, + max_points: Optional[int] = ADAPTIVE_DEFAULT_MAX_POINTS, + min_points: int = ADAPTIVE_DEFAULT_MIN_POINTS, + ) -> dict: + """Return a fresh, YAML-serializable adaptive_scan state dict.""" + return { + 'enabled': True, + 'phase': 'seed', + 'batch_size': batch_size, + 'candidate_points': list(), + 'pending_points': list(), + 'completed_points': list(), + 'failed_points': list(), + 'invalid_points': list(), + 'seed_points': list(), + 'selected_points_history': list(), + 'stopping_reason': None, + 'max_points': max_points, + 'min_points': min_points, + 'fit_metadata': dict(), + 'surface_model': dict(), + } + + +def initialize_adaptive_scan_state(rotor_dict: dict, + xyz: dict, + increment: float, + batch_size: int = ADAPTIVE_DEFAULT_BATCH_SIZE, + max_points: Optional[int] = ADAPTIVE_DEFAULT_MAX_POINTS, + min_points: int = ADAPTIVE_DEFAULT_MIN_POINTS, + ) -> None: + """ + Initialize adaptive scan state on a rotor dict if it does not already exist. 
# ---------------------------------------------------------------------------
# Seed generation
# ---------------------------------------------------------------------------

def generate_adaptive_seed_points(rotor_dict: dict,
                                  xyz: dict,
                                  increment: float,
                                  ) -> List[Tuple[float, float]]:
    """
    Generate a deterministic set of seed points for an adaptive 2D scan.

    The seed includes:
        1. The current-geometry point.
        2. A coarse grid at 3x the base increment (covering the full 2D surface sparsely).
        3. Two 1D cuts along each dimension through the current-geometry point.

    All angles are normalized to -180..+180.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        xyz (dict): The 3D coordinates (ARC xyz dict).
        increment (float): The base scan resolution in degrees.

    Returns:
        list: A list of ``(phi0, phi1)`` tuples (deduplicated).

    Raises:
        InputError: If the rotor does not define exactly 2 torsions.
    """
    torsions = rotor_dict['torsion']
    if len(torsions) != 2:
        raise InputError(f'Adaptive seed generation requires exactly 2 torsions, got {len(torsions)}')

    # Current geometry dihedral values
    orig_0 = _normalize_angle_key(
        calculate_dihedral_angle(coords=xyz['coords'], torsion=torsions[0], index=0))
    orig_1 = _normalize_angle_key(
        calculate_dihedral_angle(coords=xyz['coords'], torsion=torsions[1], index=0))

    # round() rather than int(): 360/increment can fall just below an integer
    # (e.g. 360/7.2 == 49.999...), and int() would then drop the last fine-grid
    # angle, desynchronizing the seed lattice from the dense scan grid.
    n_fine = round(360 / increment) + 1
    fine_angles_0 = [_normalize_angle_key(orig_0 + i * increment) for i in range(n_fine)]
    fine_angles_1 = [_normalize_angle_key(orig_1 + i * increment) for i in range(n_fine)]

    # Coarse grid: every 3rd step of the fine grid (ensures manageable seed count)
    coarse_step = 3
    coarse_0 = fine_angles_0[::coarse_step]
    coarse_1 = fine_angles_1[::coarse_step]

    seen = set()
    seeds = []

    def _add(p):
        # Deduplicate on the normalized string key so float noise cannot
        # produce near-duplicate seeds.
        key = point_to_key(p)
        if key not in seen:
            seen.add(key)
            seeds.append((float(key[0]), float(key[1])))

    # 1. Current conformation
    _add((orig_0, orig_1))

    # 2. Coarse grid
    for a0 in coarse_0:
        for a1 in coarse_1:
            _add((a0, a1))

    # 3. 1D cuts through origin along each dimension
    for a0 in fine_angles_0:
        _add((a0, orig_1))
    for a1 in fine_angles_1:
        _add((orig_0, a1))

    return seeds


# ---------------------------------------------------------------------------
# Bookkeeping helpers
# ---------------------------------------------------------------------------

def mark_scan_points_pending(rotor_dict: dict, points: List[list]) -> None:
    """
    Add points to the pending list in the adaptive state.

    Points already pending (by normalized key) are not duplicated.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        points (list): Points to mark pending (each a 2-element list of floats).
    """
    state = rotor_dict['adaptive_scan']
    pending_keys = {point_to_key(tuple(p)) for p in state['pending_points']}
    for p in points:
        key = point_to_key(tuple(p))
        if key not in pending_keys:
            state['pending_points'].append(list(p))
            pending_keys.add(key)
def mark_scan_point_completed(rotor_dict: dict,
                              point: list,
                              energy: Optional[float],
                              xyz: Optional[dict],
                              is_isomorphic: bool,
                              trsh: list,
                              ) -> None:
    """
    Record a completed adaptive scan point.

    Moves the point from pending to completed and writes
    into the legacy ``directed_scan`` structure.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        point (list): The 2D dihedral angles.
        energy: The electronic energy.
        xyz: The optimized geometry.
        is_isomorphic (bool): Isomorphism check result.
        trsh (list): Troubleshooting methods applied.
    """
    state = rotor_dict['adaptive_scan']
    key = point_to_key(tuple(point))
    # Remove from pending
    state['pending_points'] = [p for p in state['pending_points']
                               if point_to_key(tuple(p)) != key]
    # Add to completed (if not already there)
    if not any(point_to_key(tuple(c)) == key for c in state['completed_points']):
        state['completed_points'].append(list(point))
    # Also write into legacy directed_scan
    record_directed_scan_point(rotor_dict, point, energy, xyz, is_isomorphic, trsh)


def mark_scan_point_failed(rotor_dict: dict,
                           point: list,
                           reason: Optional[str] = None,
                           ) -> None:
    """
    Record a failed adaptive scan point.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        point (list): The 2D dihedral angles.
        reason (str, optional): Reason for failure, logged at debug level.
    """
    # Fix: 'reason' was previously accepted and silently discarded.
    if reason:
        logger.debug(f'Adaptive scan point {point} failed: {reason}')
    state = rotor_dict['adaptive_scan']
    key = point_to_key(tuple(point))
    state['pending_points'] = [p for p in state['pending_points']
                               if point_to_key(tuple(p)) != key]
    if not any(point_to_key(tuple(f)) == key for f in state['failed_points']):
        state['failed_points'].append(list(point))


def mark_scan_point_invalid(rotor_dict: dict,
                            point: list,
                            reason: Optional[str] = None,
                            ) -> None:
    """
    Record an invalid adaptive scan point (e.g. non-isomorphic).

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        point (list): The 2D dihedral angles.
        reason (str, optional): Reason for invalidation, logged at debug level.
    """
    # Fix: 'reason' was previously accepted and silently discarded.
    if reason:
        logger.debug(f'Adaptive scan point {point} marked invalid: {reason}')
    state = rotor_dict['adaptive_scan']
    key = point_to_key(tuple(point))
    state['pending_points'] = [p for p in state['pending_points']
                               if point_to_key(tuple(p)) != key]
    if not any(point_to_key(tuple(inv)) == key for inv in state['invalid_points']):
        state['invalid_points'].append(list(point))


def get_completed_adaptive_points(rotor_dict: dict) -> List[list]:
    """Return the completed points list from adaptive state."""
    return rotor_dict.get('adaptive_scan', {}).get('completed_points', [])


def get_pending_adaptive_points(rotor_dict: dict) -> List[list]:
    """Return the pending points list from adaptive state."""
    return rotor_dict.get('adaptive_scan', {}).get('pending_points', [])


def _all_visited_keys(rotor_dict: dict) -> set:
    """Return the set of string-tuple keys for all visited/submitted points."""
    state = rotor_dict.get('adaptive_scan', {})
    keys = set()
    for lst_name in ('completed_points', 'pending_points', 'failed_points', 'invalid_points'):
        for p in state.get(lst_name, []):
            keys.add(point_to_key(tuple(p)))
    return keys


# ---------------------------------------------------------------------------
# Surrogate / surface model
# ---------------------------------------------------------------------------

def fit_adaptive_surface_model(rotor_dict: dict) -> dict:
    """
    Fit a lightweight RBF-like interpolation model from completed scan points.

    The model is an inverse-distance-weighted (IDW) interpolation on
    periodic 2D angle space. The returned dict is YAML-serializable and
    contains only the data needed to evaluate predictions.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.

    Returns:
        dict: A model dict with keys ``'centers'``, ``'values'``, ``'length_scale'``.
    """
    state = rotor_dict['adaptive_scan']
    directed = rotor_dict['directed_scan']
    centers = []
    values = []
    for pt in state['completed_points']:
        entry = directed.get(point_to_key(tuple(pt)), None)
        # Only points with a parsed energy contribute to the surrogate.
        if entry is not None and entry.get('energy') is not None:
            centers.append([float(pt[0]), float(pt[1])])
            values.append(float(entry['energy']))
    model = {
        'type': 'idw',
        'centers': centers,
        'values': values,
        'length_scale': 30.0,  # degrees; controls smoothing
    }
    state['surface_model'] = model
    state['fit_metadata'] = {
        'n_points': len(centers),
    }
    return model
def predict_surface_values(model_dict: dict, query_points: List[list]) -> List[Optional[float]]:
    """
    Predict energy values at query points using the fitted model.

    Uses inverse-distance weighting with periodic angular distance.

    Args:
        model_dict (dict): A model dict as returned by :func:`fit_adaptive_surface_model`.
        query_points (list): List of ``[phi0, phi1]`` query points.

    Returns:
        list: Predicted energy values (``None`` if the model has no data).
    """
    centers = model_dict.get('centers', [])
    values = model_dict.get('values', [])
    length_scale = model_dict.get('length_scale', 30.0)

    if not centers:
        # An unfitted model cannot predict anything.
        return [None] * len(query_points)

    predictions = []
    for query in query_points:
        weights = []
        for center in centers:
            dist = _angular_distance(tuple(query), tuple(center))
            # A near-zero distance is treated as an exact hit via a huge weight.
            weights.append(1e12 if dist < 1e-8
                           else 1.0 / (dist / length_scale) ** 2)
        total_weight = sum(weights)
        if total_weight < 1e-30:
            predictions.append(None)
        else:
            weighted_sum = sum(w * v for w, v in zip(weights, values))
            predictions.append(weighted_sum / total_weight)
    return predictions


def score_candidate_points(rotor_dict: dict, candidate_points: List[list]) -> List[float]:
    """
    Score candidate points for adaptive acquisition.

    The score is a combination of:
        1. **Distance score**: How far the candidate is from the nearest sampled point
           (prefer points in under-sampled regions).
        2. **Energy score**: Lower predicted energy is mildly preferred
           (explore low-energy regions more).

    Higher score means higher priority for selection.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        candidate_points (list): Candidate ``[phi0, phi1]`` points.

    Returns:
        list: A score for each candidate (higher = more desirable).
    """
    state = rotor_dict['adaptive_scan']
    model = state.get('surface_model', {})

    # Completed AND failed points both count as "sampled" so the acquisition
    # does not keep revisiting regions that already burned compute.
    sampled = [tuple(pt) for pt in state.get('completed_points', [])]
    sampled += [tuple(pt) for pt in state.get('failed_points', [])]

    predictions = predict_surface_values(model, candidate_points)

    scores = []
    for candidate, prediction in zip(candidate_points, predictions):
        # Exploration term: distance to the nearest sampled point,
        # normalized to roughly [0, 1].
        if sampled:
            nearest = min(_angular_distance(tuple(candidate), s) for s in sampled)
        else:
            nearest = 360.0  # max possible
        score = nearest / 360.0

        # Exploitation term: mild bonus (up to 0.2) for low predicted energy.
        model_values = model.get('values')
        if prediction is not None and model_values:
            lo = min(model_values)
            spread = (max(model_values) - lo) if len(model_values) > 1 else 1.0
            if spread > 1e-10:
                score += 0.2 * (1.0 - (prediction - lo) / spread)

        scores.append(score)
    return scores
+ """ + state = rotor_dict['adaptive_scan'] + model = state.get('surface_model', {}) + + # Collect all sampled centers + sampled = [] + for pt in state.get('completed_points', []): + sampled.append(tuple(pt)) + for pt in state.get('failed_points', []): + sampled.append(tuple(pt)) + + predictions = predict_surface_values(model, candidate_points) + + scores = [] + for i, cp in enumerate(candidate_points): + # Distance to nearest sampled point + if sampled: + min_dist = min(_angular_distance(tuple(cp), s) for s in sampled) + else: + min_dist = 360.0 # max possible + + dist_score = min_dist / 360.0 # normalize to [0, 1]-ish + + # Energy preference: lower predicted energy -> mild bonus + energy_score = 0.0 + pred = predictions[i] + if pred is not None and model.get('values'): + e_range = max(model['values']) - min(model['values']) if len(model['values']) > 1 else 1.0 + if e_range > 1e-10: + energy_score = 0.2 * (1.0 - (pred - min(model['values'])) / e_range) + + scores.append(dist_score + energy_score) + return scores + + +# --------------------------------------------------------------------------- +# Candidate generation & selection +# --------------------------------------------------------------------------- + +def generate_adaptive_candidate_points(rotor_dict: dict, increment: float) -> List[list]: + """ + Generate the full set of candidate grid points that have not been visited. + + Returns all points on the full dense grid that are not yet in any + visited set (completed, pending, failed, invalid). + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Returns: + list: Unvisited ``[phi0, phi1]`` points on the full grid. + """ + visited = _all_visited_keys(rotor_dict) + + # Reconstruct the full grid angles from the seed data or from first principles. + # Use original_dihedrals if available, else start from 0. 
+ orig_dihedrals = rotor_dict.get('original_dihedrals', []) + if orig_dihedrals and len(orig_dihedrals) == 2: + start_0, start_1 = float(orig_dihedrals[0]), float(orig_dihedrals[1]) + else: + start_0, start_1 = 0.0, 0.0 + + n = int(360 / increment) + 1 + angles_0 = [_normalize_angle_key(start_0 + i * increment) for i in range(n)] + angles_1 = [_normalize_angle_key(start_1 + i * increment) for i in range(n)] + + candidates = [] + seen = set() + for a0 in angles_0: + for a1 in angles_1: + key = point_to_key((a0, a1)) + if key not in visited and key not in seen: + seen.add(key) + candidates.append([float(key[0]), float(key[1])]) + return candidates + + +def select_next_adaptive_points(rotor_dict: dict, + increment: float, + batch_size: Optional[int] = None, + ) -> List[list]: + """ + Select the next batch of points to submit for an adaptive scan. + + If the scan is in the ``'seed'`` phase, returns the unsubmitted seed points. + Otherwise fits a surrogate, scores candidates, and returns the top-scoring batch. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + batch_size (int, optional): Override the batch size from adaptive state. + + Returns: + list: Selected ``[phi0, phi1]`` points. 
+ """ + state = rotor_dict['adaptive_scan'] + bs = batch_size if batch_size is not None else state.get('batch_size', ADAPTIVE_DEFAULT_BATCH_SIZE) + visited = _all_visited_keys(rotor_dict) + + if state['phase'] == 'seed': + # Return seed points that haven't been submitted yet + unsubmitted = [] + for s in state['seed_points']: + key = point_to_key(tuple(s)) + if key not in visited: + unsubmitted.append(s) + # Transition to adaptive phase once all seeds are dispatched + if len(unsubmitted) <= bs: + state['phase'] = 'adaptive' + return unsubmitted[:bs] + + # Adaptive phase: fit model, generate candidates, score & select + candidates = generate_adaptive_candidate_points(rotor_dict, increment) + if not candidates: + return [] + + fit_adaptive_surface_model(rotor_dict) + scores = score_candidate_points(rotor_dict, candidates) + + # Sort by score descending, take top batch_size + indexed = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) + selected = [candidates[i] for i, _ in indexed[:bs]] + + state['selected_points_history'].append([list(p) for p in selected]) + return selected + + +# --------------------------------------------------------------------------- +# Stopping logic +# --------------------------------------------------------------------------- + +def should_continue_adaptive_scan(rotor_dict: dict, increment: float) -> bool: + """ + Determine whether the adaptive scan should continue submitting new batches. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Returns: + bool: ``True`` if more batches should be submitted. + """ + reason = get_adaptive_stopping_reason(rotor_dict, increment) + if reason is not None: + rotor_dict['adaptive_scan']['stopping_reason'] = reason + return False + return True + + +def get_adaptive_stopping_reason(rotor_dict: dict, increment: float) -> Optional[str]: + """ + Return the stopping reason, or ``None`` if the scan should continue. 

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.

    Returns:
        str or None: Reason for stopping, or ``None``.
    """
    state = rotor_dict['adaptive_scan']
    n_completed = len(state['completed_points'])
    n_pending = len(state['pending_points'])

    # Already stopped
    if state.get('stopping_reason') is not None:
        return state['stopping_reason']

    # Max points reached
    max_pts = state.get('max_points')
    if max_pts is not None and (n_completed + n_pending) >= max_pts:
        return 'max_points_reached'

    # No more candidates on the grid
    candidates = generate_adaptive_candidate_points(rotor_dict, increment)
    if not candidates and n_pending == 0:
        return 'grid_exhausted'

    # All grid points have been visited (full coverage).
    # Use int(360/increment) per dimension (not +1) because angle normalization
    # maps +180 to -180, so the endpoint duplicates the start.
    n_grid = int(360 / increment) ** 2
    n_visited = len(_all_visited_keys(rotor_dict))
    if n_visited >= n_grid:
        return 'full_coverage'

    # Min points check: don't stop before reaching min_points
    # NOTE(review): this check is dead code -- both the branch and the fall-through
    # return None, so min_points currently has no effect. If it is meant to veto
    # earlier stopping criteria (e.g. max_points), it must be evaluated before
    # them; confirm the intended semantics.
    min_pts = state.get('min_points', ADAPTIVE_DEFAULT_MIN_POINTS)
    if n_completed < min_pts:
        return None

    return None


def is_adaptive_scan_complete(rotor_dict: dict, increment: float) -> bool:
    """
    Check if an adaptive scan is fully complete (stopped and no pending jobs).

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.

    Returns:
        bool: ``True`` if the adaptive scan has stopped and has no pending points.
+ """ + state = rotor_dict.get('adaptive_scan', {}) + if not state.get('enabled', False): + return False + n_pending = len(state.get('pending_points', [])) + if n_pending > 0: + return False + if state.get('stopping_reason') is not None: + return True + return not should_continue_adaptive_scan(rotor_dict, increment) + + +# =========================================================================== +# Surface validation for adaptive 2D scans +# =========================================================================== + + +def _make_empty_validation_state() -> dict: + """Return a fresh, YAML-serializable validation state dict.""" + return { + 'enabled': True, + 'status': 'not_run', + 'neighbor_edges_checked': 0, + 'discontinuous_edges': [], + 'periodic_edges_checked': 0, + 'periodic_inconsistencies': [], + 'branch_jump_points': [], + 'energy_jump_summary': {}, + 'geometry_rmsd_summary': {}, + 'thresholds': { + 'energy_jump': VALIDATION_ENERGY_JUMP_THRESHOLD, + 'geometry_rmsd': VALIDATION_GEOMETRY_RMSD_THRESHOLD, + 'periodic_energy': VALIDATION_PERIODIC_ENERGY_THRESHOLD, + 'periodic_rmsd': VALIDATION_PERIODIC_RMSD_THRESHOLD, + 'branch_jump_edge_count': VALIDATION_BRANCH_JUMP_EDGE_COUNT, + }, + 'notes': [], + } + + +def _periodic_neighbor_offsets(increment: float) -> List[Tuple[float, float]]: + """Return the 4 cardinal neighbor offsets for a 2D grid.""" + return [(increment, 0.0), (-increment, 0.0), (0.0, increment), (0.0, -increment)] + + +def get_sampled_point_neighbors(rotor_dict: dict, + point: list, + increment: float, + ) -> List[list]: + """ + Return sampled neighboring points of ``point`` on the 2D scan grid. + + Neighbors are the 4 cardinal grid-adjacent points (±increment on each axis) + that exist in the completed scan data. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + point (list): ``[phi0, phi1]`` in degrees. + increment (float): The scan resolution in degrees. 
+ + Returns: + list: Neighboring points that have completed scan data. + """ + directed = rotor_dict.get('directed_scan', {}) + neighbors = [] + for d0, d1 in _periodic_neighbor_offsets(increment): + nb = [_normalize_angle_key(point[0] + d0), _normalize_angle_key(point[1] + d1)] + key = point_to_key(tuple(nb)) + if key in directed: + neighbors.append(nb) + return neighbors + + +def iter_sampled_neighbor_edges(rotor_dict: dict, + increment: float, + ) -> Iterator[Tuple[list, list]]: + """ + Yield unique pairs of neighboring sampled points for validation. + + Each edge ``(point_a, point_b)`` is yielded exactly once, where both + points have completed scan data. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Yields: + tuple: ``(point_a, point_b)`` where each is ``[phi0, phi1]``. + """ + directed = rotor_dict.get('directed_scan', {}) + seen_edges = set() + for key_tuple in directed.keys(): + pt = [float(key_tuple[0]), float(key_tuple[1])] + for d0, d1 in _periodic_neighbor_offsets(increment): + nb = [_normalize_angle_key(pt[0] + d0), _normalize_angle_key(pt[1] + d1)] + nb_key = point_to_key(tuple(nb)) + if nb_key in directed: + edge = tuple(sorted([point_to_key(tuple(pt)), nb_key])) + if edge not in seen_edges: + seen_edges.add(edge) + yield pt, nb + + +def calculate_neighbor_energy_jump(rotor_dict: dict, + point_a: list, + point_b: list, + ) -> Optional[float]: + """ + Compute the absolute energy difference between two neighboring scan points. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + point_a (list): ``[phi0, phi1]`` first point. + point_b (list): ``[phi0, phi1]`` second point. + + Returns: + float or None: Absolute energy difference in kJ/mol, or ``None`` if + either point lacks energy data. 
+ """ + directed = rotor_dict.get('directed_scan', {}) + key_a = point_to_key(tuple(point_a)) + key_b = point_to_key(tuple(point_b)) + entry_a = directed.get(key_a) + entry_b = directed.get(key_b) + if entry_a is None or entry_b is None: + return None + e_a = entry_a.get('energy') + e_b = entry_b.get('energy') + if e_a is None or e_b is None: + return None + return abs(float(e_a) - float(e_b)) + + +def calculate_neighbor_geometry_rmsd(rotor_dict: dict, + point_a: list, + point_b: list, + ) -> Optional[float]: + """ + Compute the distance-matrix RMSD between the optimized geometries of two + neighboring scan points. + + Uses the full molecular geometry (all atoms). This is a lightweight proxy + for detecting branch jumps where non-rotor atoms rearrange significantly. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + point_a (list): ``[phi0, phi1]`` first point. + point_b (list): ``[phi0, phi1]`` second point. + + Returns: + float or None: The RMSD of the two distance matrices (Angstrom), or + ``None`` if either point lacks geometry data. + """ + from arc.species.converter import compare_confs + directed = rotor_dict.get('directed_scan', {}) + key_a = point_to_key(tuple(point_a)) + key_b = point_to_key(tuple(point_b)) + entry_a = directed.get(key_a) + entry_b = directed.get(key_b) + if entry_a is None or entry_b is None: + return None + xyz_a = entry_a.get('xyz') + xyz_b = entry_b.get('xyz') + if not isinstance(xyz_a, dict) or not isinstance(xyz_b, dict): + return None + if 'coords' not in xyz_a or 'coords' not in xyz_b: + return None + try: + return compare_confs(xyz_a, xyz_b, rmsd_score=True) + except Exception: + return None + + +def classify_neighbor_edge_continuity(rotor_dict: dict, + point_a: list, + point_b: list, + energy_threshold: float = VALIDATION_ENERGY_JUMP_THRESHOLD, + rmsd_threshold: float = VALIDATION_GEOMETRY_RMSD_THRESHOLD, + ) -> dict: + """ + Classify a neighbor edge as continuous or suspicious. 
+ + An edge is flagged as discontinuous if: + - energy jump exceeds ``energy_threshold``, OR + - geometry RMSD exceeds ``rmsd_threshold`` + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + point_a (list): ``[phi0, phi1]`` first point. + point_b (list): ``[phi0, phi1]`` second point. + energy_threshold (float): Max acceptable energy jump (kJ/mol). + rmsd_threshold (float): Max acceptable distance-matrix RMSD (Angstrom). + + Returns: + dict: Classification with keys ``'continuous'`` (bool), ``'energy_jump'``, + ``'geometry_rmsd'``, ``'reasons'`` (list of str). + """ + e_jump = calculate_neighbor_energy_jump(rotor_dict, point_a, point_b) + g_rmsd = calculate_neighbor_geometry_rmsd(rotor_dict, point_a, point_b) + reasons = [] + if e_jump is not None and e_jump > energy_threshold: + reasons.append(f'energy_jump={e_jump:.2f}') + if g_rmsd is not None and g_rmsd > rmsd_threshold: + reasons.append(f'geometry_rmsd={g_rmsd:.4f}') + return { + 'continuous': len(reasons) == 0, + 'energy_jump': round(e_jump, 4) if e_jump is not None else None, + 'geometry_rmsd': round(g_rmsd, 6) if g_rmsd is not None else None, + 'reasons': reasons, + } + + +def _is_periodic_edge(point_a: list, point_b: list, increment: float) -> bool: + """Check whether an edge between two points wraps across the -180/+180 boundary.""" + for i in range(2): + diff = abs(point_a[i] - point_b[i]) + if diff > 360.0 - 1.5 * increment: + return True + return False + + +def check_periodic_edge_consistency(rotor_dict: dict, + point_a: list, + point_b: list, + energy_threshold: float = VALIDATION_PERIODIC_ENERGY_THRESHOLD, + rmsd_threshold: float = VALIDATION_PERIODIC_RMSD_THRESHOLD, + ) -> dict: + """ + Check consistency of an edge that wraps across the periodic boundary. + + Periodic edges should have similar energies/geometries if the surface + is well-behaved across the -180/+180 wrap. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. 
+ point_a (list): ``[phi0, phi1]`` first point. + point_b (list): ``[phi0, phi1]`` second point (the wrap partner). + energy_threshold (float): Max acceptable energy mismatch (kJ/mol). + rmsd_threshold (float): Max acceptable geometry RMSD (Angstrom). + + Returns: + dict: With keys ``'consistent'`` (bool), ``'energy_mismatch'``, + ``'geometry_rmsd'``, ``'reasons'`` (list of str). + """ + e_jump = calculate_neighbor_energy_jump(rotor_dict, point_a, point_b) + g_rmsd = calculate_neighbor_geometry_rmsd(rotor_dict, point_a, point_b) + reasons = [] + if e_jump is not None and e_jump > energy_threshold: + reasons.append(f'periodic_energy_mismatch={e_jump:.2f}') + if g_rmsd is not None and g_rmsd > rmsd_threshold: + reasons.append(f'periodic_geometry_mismatch={g_rmsd:.4f}') + return { + 'consistent': len(reasons) == 0, + 'energy_mismatch': round(e_jump, 4) if e_jump is not None else None, + 'geometry_rmsd': round(g_rmsd, 6) if g_rmsd is not None else None, + 'reasons': reasons, + } + + +def detect_branch_jump_points(rotor_dict: dict, + increment: float, + energy_threshold: float = VALIDATION_ENERGY_JUMP_THRESHOLD, + rmsd_threshold: float = VALIDATION_GEOMETRY_RMSD_THRESHOLD, + min_suspicious_edges: int = VALIDATION_BRANCH_JUMP_EDGE_COUNT, + ) -> List[list]: + """ + Detect points suspected of being on a different PES branch. + + A point is flagged if it is connected to ``>= min_suspicious_edges`` + discontinuous neighbor edges. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + energy_threshold (float): Energy jump threshold for edge classification. + rmsd_threshold (float): Geometry RMSD threshold for edge classification. + min_suspicious_edges (int): Minimum suspicious edges to flag a point. + + Returns: + list: Flagged points ``[[phi0, phi1], ...]``. 
+ """ + suspicious_count = {} # key_str -> count + for pt_a, pt_b in iter_sampled_neighbor_edges(rotor_dict, increment): + classification = classify_neighbor_edge_continuity( + rotor_dict, pt_a, pt_b, energy_threshold, rmsd_threshold) + if not classification['continuous']: + for pt in [pt_a, pt_b]: + k = point_to_key(tuple(pt)) + suspicious_count[k] = suspicious_count.get(k, 0) + 1 + flagged = [] + for k, count in suspicious_count.items(): + if count >= min_suspicious_edges: + flagged.append([float(k[0]), float(k[1])]) + return flagged + + +def run_adaptive_surface_validation(rotor_dict: dict, + increment: float, + ) -> dict: + """ + Compute a full surface-validation summary for an adaptive 2D scan. + + Checks all sampled neighbor edges for energy and geometry continuity, + identifies periodic wraparound inconsistencies, and flags suspected + branch-jump points. Results are stored in a YAML-serializable dict. + + This function does **not** modify stored scan data or energies. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Returns: + dict: The validation state dict (also stored in + ``rotor_dict['adaptive_scan']['validation']``). 
+ """ + validation = _make_empty_validation_state() + thresholds = validation['thresholds'] + + # --- Neighbor edge continuity --- + discontinuous = [] + energy_jumps = [] + geometry_rmsds = [] + n_edges = 0 + + for pt_a, pt_b in iter_sampled_neighbor_edges(rotor_dict, increment): + n_edges += 1 + cl = classify_neighbor_edge_continuity( + rotor_dict, pt_a, pt_b, + thresholds['energy_jump'], thresholds['geometry_rmsd']) + if cl['energy_jump'] is not None: + energy_jumps.append(cl['energy_jump']) + if cl['geometry_rmsd'] is not None: + geometry_rmsds.append(cl['geometry_rmsd']) + if not cl['continuous']: + discontinuous.append({ + 'point_a': [round(x, 2) for x in pt_a], + 'point_b': [round(x, 2) for x in pt_b], + 'energy_jump': cl['energy_jump'], + 'geometry_rmsd': cl['geometry_rmsd'], + 'reasons': cl['reasons'], + }) + + validation['neighbor_edges_checked'] = n_edges + validation['discontinuous_edges'] = discontinuous + + if energy_jumps: + validation['energy_jump_summary'] = { + 'min': round(min(energy_jumps), 4), + 'max': round(max(energy_jumps), 4), + 'mean': round(sum(energy_jumps) / len(energy_jumps), 4), + 'count': len(energy_jumps), + } + if geometry_rmsds: + validation['geometry_rmsd_summary'] = { + 'min': round(min(geometry_rmsds), 6), + 'max': round(max(geometry_rmsds), 6), + 'mean': round(sum(geometry_rmsds) / len(geometry_rmsds), 6), + 'count': len(geometry_rmsds), + } + + # --- Periodic edge consistency --- + periodic_issues = [] + n_periodic = 0 + directed = rotor_dict.get('directed_scan', {}) + for key_tuple in directed.keys(): + pt = [float(key_tuple[0]), float(key_tuple[1])] + for d0, d1 in _periodic_neighbor_offsets(increment): + nb = [_normalize_angle_key(pt[0] + d0), _normalize_angle_key(pt[1] + d1)] + if _is_periodic_edge(pt, nb, increment): + nb_key = point_to_key(tuple(nb)) + if nb_key in directed: + n_periodic += 1 + pc = check_periodic_edge_consistency( + rotor_dict, pt, nb, + thresholds['periodic_energy'], thresholds['periodic_rmsd']) + if 
not pc['consistent']: + periodic_issues.append({ + 'point_a': [round(x, 2) for x in pt], + 'point_b': [round(x, 2) for x in nb], + 'energy_mismatch': pc['energy_mismatch'], + 'geometry_rmsd': pc['geometry_rmsd'], + 'reasons': pc['reasons'], + }) + + validation['periodic_edges_checked'] = n_periodic + validation['periodic_inconsistencies'] = periodic_issues + + # --- Branch-jump detection --- + flagged = detect_branch_jump_points( + rotor_dict, increment, + thresholds['energy_jump'], thresholds['geometry_rmsd'], + thresholds['branch_jump_edge_count']) + validation['branch_jump_points'] = flagged + + # --- Status --- + if n_edges == 0: + validation['status'] = 'no_edges' + validation['notes'].append('No neighbor edges found; too few sampled points for validation.') + else: + validation['status'] = 'complete' + if discontinuous: + validation['notes'].append( + f'{len(discontinuous)} of {n_edges} neighbor edges are discontinuous.') + if periodic_issues: + validation['notes'].append( + f'{len(periodic_issues)} periodic boundary inconsistencies found.') + if flagged: + validation['notes'].append( + f'{len(flagged)} points suspected of branch jumps.') + if not discontinuous and not periodic_issues and not flagged: + validation['notes'].append('Surface passed all continuity checks.') + + return validation + + +def update_adaptive_validation_state(rotor_dict: dict, + increment: float, + ) -> None: + """ + Run surface validation and store results in the rotor's adaptive state. + + Only runs for adaptive 2D brute-force scans. Does nothing for dense or + other scan types. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. 
+ """ + if not is_adaptive_enabled(rotor_dict): + return + state = rotor_dict.get('adaptive_scan', {}) + if not state.get('enabled', False): + return + validation = run_adaptive_surface_validation(rotor_dict, increment) + state['validation'] = validation + # Log summary + n_disc = len(validation.get('discontinuous_edges', [])) + n_periodic = len(validation.get('periodic_inconsistencies', [])) + n_branch = len(validation.get('branch_jump_points', [])) + if n_disc or n_periodic or n_branch: + logger.warning(f'Adaptive scan surface validation: ' + f'{n_disc} discontinuous edges, ' + f'{n_periodic} periodic inconsistencies, ' + f'{n_branch} branch-jump suspects.') + + +# =========================================================================== +# Coupling metrics, surface quality, and ND classification +# =========================================================================== + +# Thresholds for coupling classification (V1 defaults) +COUPLING_NONSEP_THRESHOLD = 0.15 # Relative separable-fit error above this → coupled +COUPLING_CROSS_TERM_THRESHOLD = 0.10 # Cross-term fraction above this → coupled +QUALITY_MIN_POINTS = 9 # Minimum completed points for any analysis +QUALITY_FAILED_FRACTION_LIMIT = 0.20 # Above this → unreliable +QUALITY_INVALID_FRACTION_LIMIT = 0.15 # Above this → unreliable +QUALITY_DISC_EDGE_FRACTION_LIMIT = 0.25 # Above this → unreliable + + +def extract_adaptive_2d_surface_arrays(rotor_dict: dict) -> dict: + """ + Extract sampled 2D coordinates and energies into numpy arrays. + + Only includes completed points that have non-None energy. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. 
+ + Returns: + dict: ``{'phi0': np.array, 'phi1': np.array, 'energy': np.array, 'n_points': int}`` + """ + import numpy as np + directed = rotor_dict.get('directed_scan', {}) + phi0_list, phi1_list, energy_list = [], [], [] + for key, entry in directed.items(): + e = entry.get('energy') + if e is not None and len(key) == 2: + phi0_list.append(float(key[0])) + phi1_list.append(float(key[1])) + energy_list.append(float(e)) + return { + 'phi0': np.array(phi0_list, dtype=np.float64), + 'phi1': np.array(phi1_list, dtype=np.float64), + 'energy': np.array(energy_list, dtype=np.float64), + 'n_points': len(phi0_list), + } + + +def fit_separable_surface_proxy(surface_data: dict) -> dict: + """ + Build a simple separable approximation E(phi0, phi1) ≈ f(phi0) + g(phi1) + c. + + The separable components are estimated by discrete averaging: + f(phi0_i) = mean_over_phi1 { E(phi0_i, phi1_j) } - c + g(phi1_j) = mean_over_phi0 { E(phi0_i, phi1_j) } - c + c = overall mean of all sampled energies + + This works on irregularly sampled data by grouping points by their + phi0 and phi1 keys. + + Args: + surface_data (dict): As returned by :func:`extract_adaptive_2d_surface_arrays`. + + Returns: + dict: ``{'c': float, 'f_values': dict, 'g_values': dict, 'separable_predictions': np.array}`` + where ``f_values[phi0_key] = f(phi0)`` and ``g_values[phi1_key] = g(phi1)``. 
+ """ + import numpy as np + phi0 = surface_data['phi0'] + phi1 = surface_data['phi1'] + energy = surface_data['energy'] + n = surface_data['n_points'] + + if n == 0: + return {'c': 0.0, 'f_values': {}, 'g_values': {}, 'separable_predictions': np.array([])} + + c = float(np.mean(energy)) + + # Group energies by phi0 key and phi1 key + phi0_groups = {} # phi0_str -> list of energies + phi1_groups = {} # phi1_str -> list of energies + for i in range(n): + k0 = f'{phi0[i]:.2f}' + k1 = f'{phi1[i]:.2f}' + phi0_groups.setdefault(k0, []).append(energy[i]) + phi1_groups.setdefault(k1, []).append(energy[i]) + + f_values = {k: float(np.mean(v)) - c for k, v in phi0_groups.items()} + g_values = {k: float(np.mean(v)) - c for k, v in phi1_groups.items()} + + # Build separable predictions at each sampled point + sep_pred = np.zeros(n, dtype=np.float64) + for i in range(n): + k0 = f'{phi0[i]:.2f}' + k1 = f'{phi1[i]:.2f}' + sep_pred[i] = f_values.get(k0, 0.0) + g_values.get(k1, 0.0) + c + + return { + 'c': c, + 'f_values': f_values, + 'g_values': g_values, + 'separable_predictions': sep_pred, + } + + +def calculate_separable_fit_error(surface_data: dict, separable_fit: dict) -> float: + """ + Calculate the RMS error of the separable fit relative to the energy range. + + Returns the RMSE of (E_actual - E_separable) normalized by the range of E_actual. + A value near 0 means the surface is well-described by a separable model. + + Args: + surface_data (dict): As returned by :func:`extract_adaptive_2d_surface_arrays`. + separable_fit (dict): As returned by :func:`fit_separable_surface_proxy`. + + Returns: + float: Normalized RMSE (dimensionless). Returns 0.0 if insufficient data. 
+ """ + import numpy as np + energy = surface_data['energy'] + sep_pred = separable_fit['separable_predictions'] + if len(energy) < 2 or len(sep_pred) < 2: + return 0.0 + residuals = energy - sep_pred + rmse = float(np.sqrt(np.mean(residuals ** 2))) + e_range = float(np.max(energy) - np.min(energy)) + if e_range < 1e-10: + return 0.0 + return rmse / e_range + + +def calculate_nonseparability_score(surface_data: dict, separable_fit: dict) -> float: + """ + Compute a nonseparability score: the fraction of total variance NOT explained + by the separable model. + + Score near 0 → separable; score near 1 → strongly coupled. + + Formula: 1 - R² where R² = 1 - SS_res / SS_tot. + + Args: + surface_data (dict): As returned by :func:`extract_adaptive_2d_surface_arrays`. + separable_fit (dict): As returned by :func:`fit_separable_surface_proxy`. + + Returns: + float: Nonseparability score in [0, 1]. Returns 0.0 if insufficient data. + """ + import numpy as np + energy = surface_data['energy'] + sep_pred = separable_fit['separable_predictions'] + if len(energy) < 3: + return 0.0 + ss_tot = float(np.sum((energy - np.mean(energy)) ** 2)) + if ss_tot < 1e-10: + return 0.0 + ss_res = float(np.sum((energy - sep_pred) ** 2)) + r_squared = 1.0 - ss_res / ss_tot + return max(0.0, min(1.0, 1.0 - r_squared)) + + +def calculate_cross_term_strength(surface_data: dict, separable_fit: dict) -> float: + """ + Estimate the strength of cross-term coupling as the fraction of total energy + variance attributable to the non-separable residual. + + This is essentially the same as the nonseparability score but expressed as + the ratio of residual variance to total variance. + + Args: + surface_data (dict): As returned by :func:`extract_adaptive_2d_surface_arrays`. + separable_fit (dict): As returned by :func:`fit_separable_surface_proxy`. + + Returns: + float: Cross-term strength fraction in [0, 1]. 
+ """ + return calculate_nonseparability_score(surface_data, separable_fit) + + +def calculate_low_energy_path_coupling(surface_data: dict) -> float: + """ + Heuristic for low-energy-path coupling: measures whether the minimum-energy + path through the 2D surface is axis-aligned (separable) or diagonal (coupled). + + Method: among the lowest 25% of energy points, compute the correlation + coefficient between phi0 and phi1. High |correlation| suggests diagonal + low-energy valleys → coupling. + + Args: + surface_data (dict): As returned by :func:`extract_adaptive_2d_surface_arrays`. + + Returns: + float: Absolute correlation of phi0 and phi1 among low-energy points, in [0, 1]. + Returns 0.0 if insufficient data. + """ + import numpy as np + n = surface_data['n_points'] + if n < 4: + return 0.0 + energy = surface_data['energy'] + phi0 = surface_data['phi0'] + phi1 = surface_data['phi1'] + + # Select the lowest 25% of points + threshold = np.percentile(energy, 25) + mask = energy <= threshold + if mask.sum() < 3: + return 0.0 + + # Compute correlation between phi0 and phi1 in the low-energy subset + # Use sin/cos to handle periodicity + sin0 = np.sin(np.radians(phi0[mask])) + cos0 = np.cos(np.radians(phi0[mask])) + sin1 = np.sin(np.radians(phi1[mask])) + cos1 = np.cos(np.radians(phi1[mask])) + + # Cross-correlation: max of |corr(sin0,sin1)|, |corr(sin0,cos1)|, etc. + max_corr = 0.0 + for a in [sin0, cos0]: + for b in [sin1, cos1]: + if np.std(a) > 1e-10 and np.std(b) > 1e-10: + corr = abs(float(np.corrcoef(a, b)[0, 1])) + if not np.isnan(corr): + max_corr = max(max_corr, corr) + return max_corr + + +def compute_coupling_metrics(rotor_dict: dict) -> dict: + """ + Compute all coupling metrics for an adaptive 2D scan. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + + Returns: + dict: The coupling_metrics dict (also stored in ``rotor_dict['adaptive_scan']``). 
+ """ + metrics = { + 'enabled': True, + 'status': 'not_run', + 'nonseparability_score': None, + 'cross_term_strength': None, + 'low_energy_path_coupling': None, + 'separable_fit_error': None, + 'coupled_fit_proxy': None, + 'thresholds': { + 'nonseparability': COUPLING_NONSEP_THRESHOLD, + 'cross_term': COUPLING_CROSS_TERM_THRESHOLD, + }, + 'notes': [], + } + + surface = extract_adaptive_2d_surface_arrays(rotor_dict) + if surface['n_points'] < QUALITY_MIN_POINTS: + metrics['status'] = 'insufficient_data' + metrics['notes'].append(f'Only {surface["n_points"]} points; need >= {QUALITY_MIN_POINTS}.') + return metrics + + sep_fit = fit_separable_surface_proxy(surface) + metrics['nonseparability_score'] = round(calculate_nonseparability_score(surface, sep_fit), 6) + metrics['cross_term_strength'] = round(calculate_cross_term_strength(surface, sep_fit), 6) + metrics['separable_fit_error'] = round(calculate_separable_fit_error(surface, sep_fit), 6) + metrics['low_energy_path_coupling'] = round(calculate_low_energy_path_coupling(surface), 6) + metrics['coupled_fit_proxy'] = metrics['nonseparability_score'] + metrics['status'] = 'complete' + return metrics + + +def update_coupling_metrics(rotor_dict: dict) -> None: + """ + Compute and store coupling metrics on the rotor's adaptive state. + + Only runs for adaptive 2D brute-force scans. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + """ + if not is_adaptive_enabled(rotor_dict): + return + metrics = compute_coupling_metrics(rotor_dict) + rotor_dict.setdefault('adaptive_scan', {})['coupling_metrics'] = metrics + + +# --------------------------------------------------------------------------- +# Surface quality metrics +# --------------------------------------------------------------------------- + +def calculate_coverage_fraction(rotor_dict: dict, increment: float) -> float: + """ + Fraction of the full dense grid that has been visited (completed/failed/invalid). 

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.

    Returns:
        float: Coverage fraction in [0, 1].
    """
    # Use int(360/increment) per dimension (not +1) because angle normalization
    # maps +180 to -180, so the endpoint duplicates the start.
    n_grid = int(360 / increment) ** 2
    if n_grid == 0:
        return 0.0
    n_visited = len(_all_visited_keys(rotor_dict))
    return min(1.0, n_visited / n_grid)


def calculate_failed_fraction(rotor_dict: dict) -> float:
    """Fraction of submitted points that failed."""
    # NOTE(review): 'submitted' here means completed + failed + invalid;
    # pending points are excluded from the denominator -- confirm intended.
    state = rotor_dict.get('adaptive_scan', {})
    total = (len(state.get('completed_points', [])) +
             len(state.get('failed_points', [])) +
             len(state.get('invalid_points', [])))
    if total == 0:
        return 0.0
    return len(state.get('failed_points', [])) / total


def calculate_invalid_fraction(rotor_dict: dict) -> float:
    """Fraction of submitted points that are invalid (non-isomorphic)."""
    # NOTE(review): denominator duplicated from calculate_failed_fraction;
    # consider extracting a shared helper to keep the two in sync.
    state = rotor_dict.get('adaptive_scan', {})
    total = (len(state.get('completed_points', [])) +
             len(state.get('failed_points', [])) +
             len(state.get('invalid_points', [])))
    if total == 0:
        return 0.0
    return len(state.get('invalid_points', [])) / total


def calculate_validation_warning_fraction(rotor_dict: dict) -> float:
    """
    Fraction of checked neighbor edges that were flagged as discontinuous.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.

    Returns:
        float: Warning fraction in [0, 1].
    """
    validation = rotor_dict.get('adaptive_scan', {}).get('validation', {})
    n_edges = validation.get('neighbor_edges_checked', 0)
    if n_edges == 0:
        return 0.0
    n_disc = len(validation.get('discontinuous_edges', []))
    return n_disc / n_edges


def calculate_periodic_consistency_score(rotor_dict: dict) -> float:
    """
    Periodic consistency as (1 - fraction of inconsistent periodic edges).
+ + Returns 1.0 if all periodic edges are consistent or if no edges were checked. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + + Returns: + float: Score in [0, 1]. Higher is better. + """ + validation = rotor_dict.get('adaptive_scan', {}).get('validation', {}) + n_periodic = validation.get('periodic_edges_checked', 0) + if n_periodic == 0: + return 1.0 + n_issues = len(validation.get('periodic_inconsistencies', [])) + return 1.0 - n_issues / n_periodic + + +def calculate_overall_quality_score(rotor_dict: dict, increment: float) -> float: + """ + Compute a composite surface quality score in [0, 1]. + + Weighted combination: + - 30% coverage fraction + - 25% (1 - failed fraction) + - 20% (1 - invalid fraction) + - 15% (1 - validation warning fraction) + - 10% periodic consistency score + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Returns: + float: Quality score in [0, 1]. Higher is better. + """ + cov = calculate_coverage_fraction(rotor_dict, increment) + fail = calculate_failed_fraction(rotor_dict) + inv = calculate_invalid_fraction(rotor_dict) + warn = calculate_validation_warning_fraction(rotor_dict) + per = calculate_periodic_consistency_score(rotor_dict) + return 0.30 * cov + 0.25 * (1.0 - fail) + 0.20 * (1.0 - inv) + 0.15 * (1.0 - warn) + 0.10 * per + + +def compute_surface_quality_metrics(rotor_dict: dict, increment: float) -> dict: + """ + Compute all surface quality metrics for an adaptive 2D scan. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. + + Returns: + dict: The surface_quality dict. 
+ """ + state = rotor_dict.get('adaptive_scan', {}) + n_completed = len(state.get('completed_points', [])) + metrics = { + 'enabled': True, + 'status': 'complete', + 'coverage_fraction': round(calculate_coverage_fraction(rotor_dict, increment), 4), + 'completed_fraction': None, + 'failed_fraction': round(calculate_failed_fraction(rotor_dict), 4), + 'invalid_fraction': round(calculate_invalid_fraction(rotor_dict), 4), + 'validation_warning_fraction': round(calculate_validation_warning_fraction(rotor_dict), 4), + 'periodic_consistency_score': round(calculate_periodic_consistency_score(rotor_dict), 4), + 'quality_score': round(calculate_overall_quality_score(rotor_dict, increment), 4), + 'thresholds': { + 'min_points': QUALITY_MIN_POINTS, + 'failed_fraction_limit': QUALITY_FAILED_FRACTION_LIMIT, + 'invalid_fraction_limit': QUALITY_INVALID_FRACTION_LIMIT, + 'disc_edge_fraction_limit': QUALITY_DISC_EDGE_FRACTION_LIMIT, + }, + 'notes': [], + } + total = (n_completed + len(state.get('failed_points', [])) + + len(state.get('invalid_points', []))) + if total > 0: + metrics['completed_fraction'] = round(n_completed / total, 4) + if n_completed < QUALITY_MIN_POINTS: + metrics['status'] = 'insufficient_data' + metrics['notes'].append(f'Only {n_completed} completed points; need >= {QUALITY_MIN_POINTS}.') + return metrics + + +def update_surface_quality_metrics(rotor_dict: dict, increment: float) -> None: + """ + Compute and store surface quality metrics on the rotor's adaptive state. + + Args: + rotor_dict (dict): A single entry from ``species.rotors_dict``. + increment (float): The scan resolution in degrees. 
    """
    if not is_adaptive_enabled(rotor_dict):
        # No-op for dense / non-adaptive rotors.
        return
    metrics = compute_surface_quality_metrics(rotor_dict, increment)
    rotor_dict.setdefault('adaptive_scan', {})['surface_quality'] = metrics


# ---------------------------------------------------------------------------
# ND rotor classification
# ---------------------------------------------------------------------------

def classify_adaptive_nd_rotor(rotor_dict: dict, increment: float) -> dict:
    """
    Classify an adaptive 2D ND rotor as separable, coupled, or unreliable.

    Logic:
        1. If surface quality is insufficient or too many failures → ``"unreliable"``
        2. If quality is acceptable and nonseparability is below threshold → ``"separable"``
        3. If quality is acceptable and nonseparability is above threshold → ``"coupled"``

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.
            NOTE(review): not referenced in this function body; presumably kept
            for signature symmetry with the other quality helpers -- confirm.

    Returns:
        dict: The nd_classification dict.
    """
    result = {
        'enabled': True,
        'status': 'not_run',
        'classification': None,
        'confidence': None,
        'reason': None,
        'recommended_action': None,
        'notes': [],
    }

    # Ensure we have quality and coupling metrics
    state = rotor_dict.get('adaptive_scan', {})
    quality = state.get('surface_quality', {})
    coupling = state.get('coupling_metrics', {})

    # Check if we have enough data
    if quality.get('status') == 'insufficient_data' or coupling.get('status') == 'insufficient_data':
        result['status'] = 'insufficient_data'
        result['classification'] = 'unreliable'
        result['reason'] = 'Insufficient completed points for reliable analysis.'
        result['recommended_action'] = 'fallback_due_to_surface_quality'
        result['confidence'] = 0.0
        return result

    # Check quality thresholds for unreliable.
    # 'or 0.0' additionally maps explicit None values (e.g. an unset
    # completed_fraction) to 0.0 before the numeric comparisons below.
    failed_frac = quality.get('failed_fraction', 0.0) or 0.0
    invalid_frac = quality.get('invalid_fraction', 0.0) or 0.0
    warn_frac = quality.get('validation_warning_fraction', 0.0) or 0.0
    quality_score = quality.get('quality_score', 0.0) or 0.0

    unreliable_reasons = []
    if failed_frac > QUALITY_FAILED_FRACTION_LIMIT:
        unreliable_reasons.append(f'failed_fraction={failed_frac:.2f} > {QUALITY_FAILED_FRACTION_LIMIT}')
    if invalid_frac > QUALITY_INVALID_FRACTION_LIMIT:
        unreliable_reasons.append(f'invalid_fraction={invalid_frac:.2f} > {QUALITY_INVALID_FRACTION_LIMIT}')
    if warn_frac > QUALITY_DISC_EDGE_FRACTION_LIMIT:
        unreliable_reasons.append(f'disc_edge_fraction={warn_frac:.2f} > {QUALITY_DISC_EDGE_FRACTION_LIMIT}')

    if unreliable_reasons:
        result['status'] = 'complete'
        result['classification'] = 'unreliable'
        result['reason'] = '; '.join(unreliable_reasons)
        result['recommended_action'] = 'fallback_due_to_surface_quality'
        # Worse quality -> higher confidence that the surface is unreliable.
        result['confidence'] = round(max(0.0, 1.0 - quality_score), 2)
        result['notes'].append('Surface quality issues prevent reliable coupling analysis.')
        return result

    # Classify based on coupling
    nonsep = coupling.get('nonseparability_score', 0.0) or 0.0
    cross_term = coupling.get('cross_term_strength', 0.0) or 0.0

    # Either indicator exceeding its threshold is sufficient to call the rotor coupled.
    is_coupled = (nonsep > COUPLING_NONSEP_THRESHOLD or cross_term > COUPLING_CROSS_TERM_THRESHOLD)

    result['status'] = 'complete'
    if is_coupled:
        result['classification'] = 'coupled'
        result['reason'] = (f'nonseparability={nonsep:.4f} (threshold={COUPLING_NONSEP_THRESHOLD}), '
                            f'cross_term={cross_term:.4f} (threshold={COUPLING_CROSS_TERM_THRESHOLD})')
        result['recommended_action'] = 'retain_as_coupled_2d_surface'
        result['confidence'] = round(min(1.0, nonsep / COUPLING_NONSEP_THRESHOLD), 2)
    else:
        result['classification'] = 'separable'
        result['reason'] = (f'nonseparability={nonsep:.4f} (threshold={COUPLING_NONSEP_THRESHOLD}), '
                            f'cross_term={cross_term:.4f} (threshold={COUPLING_CROSS_TERM_THRESHOLD})')
        result['recommended_action'] = 'treat_as_separable_1d_like'
        # Confidence grows as nonsep falls further below the threshold.
        result['confidence'] = round(min(1.0, (COUPLING_NONSEP_THRESHOLD - nonsep) / COUPLING_NONSEP_THRESHOLD), 2)

    return result


def update_nd_classification(rotor_dict: dict, increment: float) -> None:
    """
    Run coupling metrics, surface quality, and classification, and store all
    results on the rotor's adaptive state.

    Only runs for adaptive 2D brute-force scans.

    Args:
        rotor_dict (dict): A single entry from ``species.rotors_dict``.
        increment (float): The scan resolution in degrees.
    """
    if not is_adaptive_enabled(rotor_dict):
        return
    state = rotor_dict.setdefault('adaptive_scan', {})

    # Compute coupling metrics
    update_coupling_metrics(rotor_dict)

    # Compute surface quality
    update_surface_quality_metrics(rotor_dict, increment)

    # Classify
    classification = classify_adaptive_nd_rotor(rotor_dict, increment)
    state['nd_classification'] = classification

    # Log
    cls = classification.get('classification', 'unknown')
    reason = classification.get('reason', '')
    action = classification.get('recommended_action', '')
    logger.info(f'Adaptive 2D rotor classified as "{cls}": {reason}. '
                f'Recommended: {action}.')
diff --git a/arc/species/nd_scan_test.py b/arc/species/nd_scan_test.py
new file mode 100644
index 0000000000..6106efb669
--- /dev/null
+++ b/arc/species/nd_scan_test.py
@@ -0,0 +1,1469 @@
#!/usr/bin/env python3
# encoding: utf-8

"""
This module contains unit tests of the arc.species.nd_scan module
"""

import math
import unittest

from arc.exceptions import SchedulerError
from arc.species.nd_scan import (decrement_running_jobs,
                                 finalize_directed_scan_results,
                                 fit_adaptive_surface_model,
                                 format_dihedral_key,
                                 generate_adaptive_candidate_points,
                                 generate_adaptive_seed_points,
                                 get_completed_adaptive_points,
                                 get_continuous_scan_dihedrals,
                                 get_pending_adaptive_points,
                                 get_rotor_dict_by_pivots,
                                 get_torsion_dihedral_grid,
                                 get_adaptive_stopping_reason,
                                 increment_continuous_scan_indices,
                                 initialize_adaptive_scan_state,
                                 initialize_continuous_scan_state,
                                 is_adaptive_eligible,
                                 is_adaptive_enabled,
                                 is_adaptive_scan_complete,
                                 is_continuous_scan_complete,
                                 iter_brute_force_scan_points,
                                 mark_scan_point_completed,
                                 mark_scan_point_failed,
                                 mark_scan_point_invalid,
                                 mark_scan_points_pending,
                                 predict_surface_values,
                                 score_candidate_points,
                                 select_next_adaptive_points,
                                 should_continue_adaptive_scan,
                                 normalize_directed_scan_energies,
                                 record_directed_scan_point,
                                 validate_scan_resolution,
                                 calculate_neighbor_energy_jump,
                                 calculate_neighbor_geometry_rmsd,
                                 classify_neighbor_edge_continuity,
                                 detect_branch_jump_points,
                                 get_sampled_point_neighbors,
                                 iter_sampled_neighbor_edges,
                                 run_adaptive_surface_validation,
                                 update_adaptive_validation_state,
                                 check_periodic_edge_consistency,
                                 extract_adaptive_2d_surface_arrays,
                                 fit_separable_surface_proxy,
                                 calculate_separable_fit_error,
                                 calculate_nonseparability_score,
                                 calculate_cross_term_strength,
                                 calculate_low_energy_path_coupling,
                                 compute_coupling_metrics,
                                 compute_surface_quality_metrics,
                                 calculate_coverage_fraction,
                                 update_nd_classification,
                                 COUPLING_NONSEP_THRESHOLD,
                                 )


class TestNDScan(unittest.TestCase):
    """
    Contains unit tests for the nd_scan module
    """

    @classmethod
    def setUpClass(cls):
        """
        A method that is run before all unit tests in this class.
        """
        cls.maxDiff = None

    def test_validate_scan_resolution(self):
        """Test scan resolution validation."""
        # Valid resolutions
        validate_scan_resolution(8.0)
        validate_scan_resolution(10.0)
        validate_scan_resolution(1.0)
        validate_scan_resolution(120.0)
        validate_scan_resolution(360.0)

        # Invalid resolutions (do not divide 360 evenly)
        with self.assertRaises(SchedulerError):
            validate_scan_resolution(7.0)
        with self.assertRaises(SchedulerError):
            validate_scan_resolution(11.0)
        with self.assertRaises(SchedulerError):
            validate_scan_resolution(13.0)

    def test_get_torsion_dihedral_grid(self):
        """Test generating dihedral grids from xyz coordinates."""
        # Create a simple xyz dict: 4 atoms in a known geometry
        # Using a methanol-like geometry (H-O-C-H dihedral)
        xyz = {'symbols': ('H', 'O', 'C', 'H'),
               'isotopes': (1, 16, 12, 1),
               'coords': ((0.0, 0.0, 1.0),
                          (0.0, 0.0, 0.0),
                          (1.0, 0.0, 0.0),
                          (1.5, 1.0, 0.0))}
        torsions = [[0, 1, 2, 3]]
        increment = 120.0  # 3 + 1 = 4 points for easy verification
        grid = get_torsion_dihedral_grid(xyz, torsions, increment)

        self.assertEqual(len(grid), 1)
        key = tuple(torsions[0])
        self.assertIn(key, grid)
        self.assertEqual(len(grid[key]), int(360 / 120) + 1)  # 4 points
        # All angles should be in -180..+180 range
        for angle in grid[key]:
            self.assertGreaterEqual(angle, -180.0)
            self.assertLessEqual(angle, 180.0)

    def test_iter_brute_force_scan_points_1d(self):
        """Test 1D brute-force point generation."""
        dihedrals = {(0, 1, 2, 3): [0.0, 90.0, 180.0, -90.0]}
        torsions = [[0, 1, 2, 3]]
        points = list(iter_brute_force_scan_points(dihedrals, torsions, diagonal=False))
        self.assertEqual(len(points), 4)
        self.assertEqual(points[0], (0.0,))
        self.assertEqual(points[1], (90.0,))
        self.assertEqual(points[2], (180.0,))
        self.assertEqual(points[3], (-90.0,))

    def test_iter_brute_force_scan_points_2d(self):
        """Test 2D brute-force point generation (cartesian product)."""
        dihedrals = {(0, 1, 2, 3): [0.0, 120.0, -120.0],
                     (4, 5, 6, 7): [10.0, 130.0, -110.0]}
        torsions = [[0, 1, 2, 3], [4, 5, 6, 7]]
        points = list(iter_brute_force_scan_points(dihedrals, torsions, diagonal=False))
        # 3 x 3 = 9 combinations
        self.assertEqual(len(points), 9)
        # First point: first angle of each torsion
        self.assertEqual(points[0], (0.0, 10.0))
        # Second point: first torsion stays, second increments
        self.assertEqual(points[1], (0.0, 130.0))
        # Last point
        self.assertEqual(points[-1], (-120.0, -110.0))

    def test_iter_brute_force_scan_points_2d_diagonal(self):
        """Test 2D diagonal brute-force point generation."""
        dihedrals = {(0, 1, 2, 3): [0.0, 120.0, -120.0],
                     (4, 5, 6, 7): [10.0, 130.0, -110.0]}
        torsions = [[0, 1, 2, 3], [4, 5, 6, 7]]
        points = list(iter_brute_force_scan_points(dihedrals, torsions, diagonal=True))
        # diagonal: only 3 points (one per step)
        self.assertEqual(len(points), 3)
        self.assertEqual(points[0], (0.0, 10.0))
        self.assertEqual(points[1], (120.0, 130.0))
        self.assertEqual(points[2], (-120.0, -110.0))

    def test_initialize_continuous_scan_state(self):
        """Test continuous scan state initialization."""
        rotor_dict = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'scan': [[1, 2, 3, 4], [5, 6, 7, 8]],
            'cont_indices': list(),
            'original_dihedrals': list(),
        }
        # Non-collinear geometry so dihedral angles are well-defined.
        # scan uses 1-indexed atoms, so scan [1,2,3,4] refers to atoms 0,1,2,3 in 0-index.
        xyz = {'symbols': ('H', 'O', 'C', 'H', 'N', 'C', 'O', 'H'),
               'isotopes': (1, 16, 12, 1, 14, 12, 16, 1),
               'coords': ((0.0, 1.0, 0.5),
                          (0.0, 0.0, 0.0),
                          (1.0, 0.0, 0.0),
                          (1.5, 1.0, 0.5),
                          (3.0, 1.0, 0.0),
                          (4.0, 0.0, 0.5),
                          (5.0, 0.0, 0.0),
                          (5.5, 1.0, 0.5))}
        initialize_continuous_scan_state(rotor_dict, xyz)

        self.assertEqual(rotor_dict['cont_indices'], [0, 0])
        self.assertEqual(len(rotor_dict['original_dihedrals']), 2)
        # original_dihedrals should be strings with 2 decimal places
        for d in rotor_dict['original_dihedrals']:
            self.assertIsInstance(d, str)
            self.assertIn('.', d)

    def test_initialize_continuous_scan_state_idempotent(self):
        """Test that initialization doesn't overwrite existing state."""
        rotor_dict = {
            'torsion': [[0, 1, 2, 3]],
            'scan': [[1, 2, 3, 4]],
            'cont_indices': [5],
            'original_dihedrals': ['45.00'],
        }
        xyz = {'symbols': ('H', 'O', 'C', 'H'),
               'isotopes': (1, 16, 12, 1),
               'coords': ((0.0, 0.0, 1.0),
                          (0.0, 0.0, 0.0),
                          (1.0, 0.0, 0.0),
                          (1.5, 1.0, 0.0))}
        initialize_continuous_scan_state(rotor_dict, xyz)
        # Should NOT overwrite existing values
        self.assertEqual(rotor_dict['cont_indices'], [5])
        self.assertEqual(rotor_dict['original_dihedrals'], ['45.00'])

    def test_get_continuous_scan_dihedrals(self):
        """Test computing dihedrals for a continuous scan step."""
        rotor_dict = {
            'torsion': [[0, 1, 2, 3]],
            'original_dihedrals': ['0.00'],
            'cont_indices': [3],
        }
        increment = 10.0
        dihedrals = get_continuous_scan_dihedrals(rotor_dict, increment)
        self.assertEqual(len(dihedrals), 1)
        # index 3 at 10 deg/step -> original + 30 deg
        self.assertAlmostEqual(dihedrals[0], 30.0, places=2)

    def test_get_continuous_scan_dihedrals_wrapping(self):
        """Test that continuous scan dihedrals wrap around -180..+180."""
        rotor_dict = {
            'torsion': [[0, 1, 2, 3]],
            'original_dihedrals': ['170.00'],
            'cont_indices': [3],
        }
        increment = 10.0
        dihedrals = get_continuous_scan_dihedrals(rotor_dict, increment)
        # 170 + 30 = 200, wrapped to -160
        self.assertAlmostEqual(dihedrals[0], -160.0, places=2)

    def test_is_continuous_scan_complete(self):
        """Test continuous scan completion detection."""
        increment = 10.0  # 37 grid points per dimension

        # Not complete: last index not at max
        rotor_dict_incomplete = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [36, 0],
        }
        self.assertFalse(is_continuous_scan_complete(rotor_dict_incomplete, increment))

        # Complete: last index at max
        rotor_dict_complete = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [0, 36],
        }
        self.assertTrue(is_continuous_scan_complete(rotor_dict_complete, increment))

        # 1D complete
        rotor_dict_1d_complete = {
            'torsion': [[0, 1, 2, 3]],
            'cont_indices': [36],
        }
        self.assertTrue(is_continuous_scan_complete(rotor_dict_1d_complete, increment))

    def test_increment_continuous_scan_indices_non_diagonal(self):
        """Test non-diagonal continuous scan index incrementing (odometer-style)."""
        increment = 120.0  # 4 grid points, indices 0..3

        # Simple increment of first index
        rotor_dict = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [0, 0],
        }
        increment_continuous_scan_indices(rotor_dict, increment, diagonal=False)
        self.assertEqual(rotor_dict['cont_indices'], [1, 0])

        # First index at max -> rolls over to 0, second increments
        rotor_dict = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [3, 0],
        }
        increment_continuous_scan_indices(rotor_dict, increment, diagonal=False)
        self.assertEqual(rotor_dict['cont_indices'], [0, 1])

        # Middle of scan
        rotor_dict = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [2, 1],
        }
        increment_continuous_scan_indices(rotor_dict, increment, diagonal=False)
        self.assertEqual(rotor_dict['cont_indices'], [3, 1])

    def test_increment_continuous_scan_indices_diagonal(self):
        """Test diagonal continuous scan index incrementing."""
        increment = 120.0

        rotor_dict = {
            'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]],
            'cont_indices': [0, 0],
        }
        # Diagonal mode advances all indices in lockstep.
        increment_continuous_scan_indices(rotor_dict, increment, diagonal=True)
        self.assertEqual(rotor_dict['cont_indices'], [1, 1])

        increment_continuous_scan_indices(rotor_dict, increment, diagonal=True)
        self.assertEqual(rotor_dict['cont_indices'], [2, 2])

    def test_normalize_directed_scan_energies(self):
        """Test energy normalization from a mock directed_scan dict."""
        rotor_dict = {
            'directed_scan_type': 'brute_force_opt',
            'scan': [[1, 2, 3, 4]],
            'directed_scan': {
                ('0.00',): {'energy': -100.5, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('120.00',): {'energy': -100.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': ['some_method']},
                ('-120.00',): {'energy': -100.3, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
            },
        }
        results, trshed_points = normalize_directed_scan_energies(rotor_dict)
        self.assertEqual(trshed_points, 1)
        self.assertEqual(results['directed_scan_type'], 'brute_force_opt')
        self.assertEqual(results['scans'], [[1, 2, 3, 4]])
        # Minimum energy is -100.5, so ('0.00',) should be 0.0
        self.assertAlmostEqual(results['directed_scan'][('0.00',)]['energy'], 0.0)
        self.assertAlmostEqual(results['directed_scan'][('120.00',)]['energy'], 0.5)
        self.assertAlmostEqual(results['directed_scan'][('-120.00',)]['energy'], 0.2)

    def test_normalize_directed_scan_energies_with_none(self):
        """Test energy normalization when some energies are None."""
        rotor_dict = {
            'directed_scan_type': 'brute_force_sp',
            'scan': [[1, 2, 3, 4]],
            'directed_scan': {
                ('0.00',): {'energy': -50.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('90.00',): {'energy': None, 'xyz': {}, 'is_isomorphic': False, 'trsh': ['method1']},
                ('180.00',): {'energy': -45.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
            },
        }
        results, trshed_points = normalize_directed_scan_energies(rotor_dict)
        self.assertEqual(trshed_points, 1)
        self.assertAlmostEqual(results['directed_scan'][('0.00',)]['energy'], 0.0)
        # None energies survive normalization as None rather than raising.
        self.assertIsNone(results['directed_scan'][('90.00',)]['energy'])
        self.assertAlmostEqual(results['directed_scan'][('180.00',)]['energy'], 5.0)

    def test_normalize_directed_scan_energies_2d(self):
        """Test energy normalization for a 2D scan."""
        rotor_dict = {
            'directed_scan_type': 'brute_force_opt',
            'scan': [[1, 2, 3, 4], [5, 6, 7, 8]],
            'directed_scan': {
                ('0.00', '0.00'): {'energy': -200.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('0.00', '120.00'): {'energy': -195.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('120.00', '0.00'): {'energy': -198.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('120.00', '120.00'): {'energy': -190.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
            },
        }
        results, trshed_points = normalize_directed_scan_energies(rotor_dict)
        self.assertEqual(trshed_points, 0)
        self.assertAlmostEqual(results['directed_scan'][('0.00', '0.00')]['energy'], 0.0)
        self.assertAlmostEqual(results['directed_scan'][('0.00', '120.00')]['energy'], 5.0)
        self.assertAlmostEqual(results['directed_scan'][('120.00', '0.00')]['energy'], 2.0)
        self.assertAlmostEqual(results['directed_scan'][('120.00', '120.00')]['energy'], 10.0)


    def test_format_dihedral_key(self):
        """Test legacy string-tuple key formatting."""
        key = format_dihedral_key([180.0, -170.0])
        self.assertEqual(key, ('180.00', '-170.00'))

        key_1d = format_dihedral_key([0.0])
        self.assertEqual(key_1d, ('0.00',))

        key_3d = format_dihedral_key([45.123, -90.456, 0.0])
        self.assertEqual(key_3d, ('45.12', '-90.46', '0.00'))

    def test_record_directed_scan_point(self):
        """Test that record_directed_scan_point writes the exact legacy shape."""
        rotor_dict = {'directed_scan': {}}
        record_directed_scan_point(
            rotor_dict=rotor_dict,
            dihedrals=[180.0, -170.0],
            energy=-100.5,
            xyz={'symbols': ('H',), 'coords': ((0.0, 0.0, 0.0),)},
            is_isomorphic=True,
            trsh=['method1'],
        )
        key = ('180.00', '-170.00')
        self.assertIn(key, rotor_dict['directed_scan'])
        entry = rotor_dict['directed_scan'][key]
        self.assertEqual(entry['energy'], -100.5)
        self.assertEqual(entry['xyz'], {'symbols': ('H',), 'coords': ((0.0, 0.0, 0.0),)})
        self.assertTrue(entry['is_isomorphic'])
        self.assertEqual(entry['trsh'], ['method1'])

    def test_record_directed_scan_point_1d(self):
        """Test record for a 1D scan point."""
        rotor_dict = {'directed_scan': {}}
        record_directed_scan_point(
            rotor_dict=rotor_dict,
            dihedrals=[90.0],
            energy=-50.0,
            xyz=None,
            is_isomorphic=False,
            trsh=[],
        )
        key = ('90.00',)
        self.assertIn(key, rotor_dict['directed_scan'])
        self.assertIsNone(rotor_dict['directed_scan'][key]['xyz'])
        self.assertFalse(rotor_dict['directed_scan'][key]['is_isomorphic'])

    def test_record_directed_scan_point_overwrites(self):
        """Test that recording the same point again overwrites the previous entry."""
        rotor_dict = {'directed_scan': {
            ('0.00',): {'energy': -100.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
        }}
        record_directed_scan_point(
            rotor_dict=rotor_dict,
            dihedrals=[0.0],
            energy=-200.0,
            xyz={'new': True},
            is_isomorphic=False,
            trsh=['trsh1'],
        )
        self.assertEqual(rotor_dict['directed_scan'][('0.00',)]['energy'], -200.0)
        self.assertFalse(rotor_dict['directed_scan'][('0.00',)]['is_isomorphic'])

    def test_get_rotor_dict_by_pivots_found(self):
        """Test pivot lookup when the pivots exist."""
        rotors_dict = {
            0: {'pivots': [1, 2], 'scan': [[1, 2, 3, 4]]},
            1: {'pivots': [3, 4], 'scan': [[3, 4, 5, 6]]},
        }
        match = get_rotor_dict_by_pivots(rotors_dict, [3, 4])
        self.assertIsNotNone(match)
        idx, rd = match
        self.assertEqual(idx, 1)
        self.assertEqual(rd['pivots'], [3, 4])

    def test_get_rotor_dict_by_pivots_not_found(self):
        """Test pivot lookup when the pivots do not exist."""
        rotors_dict = {
            0: {'pivots': [1, 2], 'scan': [[1, 2, 3, 4]]},
        }
        match = get_rotor_dict_by_pivots(rotors_dict, [99, 100])
        self.assertIsNone(match)

    def test_get_rotor_dict_by_pivots_nested(self):
        """Test pivot lookup with nested pivots (list of lists)."""
        rotors_dict = {
            0: {'pivots': [[1, 2], [3, 4]], 'scan': [[1, 2, 3, 4], [3, 4, 5, 6]]},
        }
        match = get_rotor_dict_by_pivots(rotors_dict, [[1, 2], [3, 4]])
        self.assertIsNotNone(match)
        idx, rd = match
        self.assertEqual(idx, 0)

    def test_finalize_directed_scan_results_non_ess(self):
        """Test finalize produces the same payload as normalize for non-ESS scans."""
        rotor_dict = {
            'directed_scan_type': 'brute_force_opt',
            'scan': [[1, 2, 3, 4]],
            'directed_scan': {
                ('0.00',): {'energy': -100.5, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
                ('120.00',): {'energy': -100.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': ['m']},
                ('-120.00',): {'energy': -100.3, 'xyz': {}, 'is_isomorphic': True, 'trsh': []},
            },
        }
        results, trshed = finalize_directed_scan_results(rotor_dict)
        self.assertEqual(trshed, 1)
        self.assertAlmostEqual(results['directed_scan'][('0.00',)]['energy'], 0.0)
        self.assertAlmostEqual(results['directed_scan'][('120.00',)]['energy'], 0.5)
        self.assertEqual(results['directed_scan_type'], 'brute_force_opt')
        self.assertEqual(results['scans'], [[1, 2, 3, 4]])

    def test_finalize_directed_scan_results_ess(self):
        """Test finalize delegates to parser for ESS scans."""
        mock_results = {'directed_scan_type': 'ess', 'scans': [[1, 2, 3, 4]], 'directed_scan': {}}
        rotor_dict = {
            'directed_scan_type': 'ess',
            'scan_path': '/fake/path.log',
            'scan': [[1, 2, 3, 4]],
            'directed_scan': {},
        }

        def mock_parse(log_file_path):
            self.assertEqual(log_file_path, '/fake/path.log')
            return [mock_results]

        results, trshed = finalize_directed_scan_results(rotor_dict, parse_nd_scan_energies_func=mock_parse)
        self.assertEqual(trshed, 0)
        self.assertIs(results, mock_results)

    def test_finalize_directed_scan_results_ess_no_func_raises(self):
        """Test that ESS finalize raises if no parser func is given."""
        rotor_dict = {
            'directed_scan_type': 'ess',
            'scan_path': '/fake/path.log',
            'scan': [[1, 2, 3, 4]],
            'directed_scan': {},
        }
        with self.assertRaises(ValueError):
            finalize_directed_scan_results(rotor_dict)

    def test_decrement_running_jobs(self):
        """Test brute-force job counter decrement."""
        rotor_dict = {'number_of_running_jobs': 3}
        # Returns True only once the counter reaches zero.
        self.assertFalse(decrement_running_jobs(rotor_dict))
        self.assertEqual(rotor_dict['number_of_running_jobs'], 2)

        self.assertFalse(decrement_running_jobs(rotor_dict))
        self.assertEqual(rotor_dict['number_of_running_jobs'], 1)

        self.assertTrue(decrement_running_jobs(rotor_dict))
        self.assertEqual(rotor_dict['number_of_running_jobs'], 0)

    def test_decrement_running_jobs_already_zero(self):
        """Test that decrementing past zero clamps to 0 and signals done."""
        rotor_dict = {'number_of_running_jobs': 0}
        result = decrement_running_jobs(rotor_dict)
        self.assertTrue(result)  # clamped to 0, treated as done
        self.assertEqual(rotor_dict['number_of_running_jobs'], 0)


class TestAdaptiveNDScan(unittest.TestCase):
    """
    Contains unit tests for the adaptive sparse 2D scan functionality in nd_scan module.
+ """ + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + cls.maxDiff = None + # Non-collinear 8-atom geometry for 2D scans + cls.xyz = { + 'symbols': ('H', 'O', 'C', 'H', 'N', 'C', 'O', 'H'), + 'isotopes': (1, 16, 12, 1, 14, 12, 16, 1), + 'coords': ((0.0, 1.0, 0.5), + (0.0, 0.0, 0.0), + (1.0, 0.0, 0.0), + (1.5, 1.0, 0.5), + (3.0, 1.0, 0.0), + (4.0, 0.0, 0.5), + (5.0, 0.0, 0.0), + (5.5, 1.0, 0.5)), + } + + def _make_2d_rotor_dict(self, scan_type='brute_force_opt', policy='dense'): + """Helper to create a 2D rotor dict for testing.""" + rd = { + 'pivots': [[1, 2], [5, 6]], + 'top': [[0], [7]], + 'scan': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]], + 'number_of_running_jobs': 0, + 'success': None, + 'invalidation_reason': '', + 'times_dihedral_set': 0, + 'trsh_counter': 0, + 'trsh_methods': list(), + 'scan_path': '', + 'directed_scan_type': scan_type, + 'directed_scan': dict(), + 'dimensions': 2, + 'original_dihedrals': list(), + 'cont_indices': list(), + 'symmetry': None, + 'max_e': None, + } + if policy == 'adaptive': + rd['sampling_policy'] = 'adaptive' + return rd + + # -- Eligibility tests -- + + def test_is_adaptive_eligible_brute_force_2d(self): + """Test eligibility for a 2D brute_force_opt rotor.""" + rd = self._make_2d_rotor_dict('brute_force_opt') + self.assertTrue(is_adaptive_eligible(rd)) + + def test_is_adaptive_eligible_brute_force_sp_2d(self): + """Test eligibility for a 2D brute_force_sp rotor.""" + rd = self._make_2d_rotor_dict('brute_force_sp') + self.assertTrue(is_adaptive_eligible(rd)) + + def test_is_adaptive_ineligible_ess(self): + """Test ESS scans are not eligible.""" + rd = self._make_2d_rotor_dict('ess') + self.assertFalse(is_adaptive_eligible(rd)) + + def test_is_adaptive_ineligible_cont_opt(self): + """Test continuous scans are not eligible.""" + rd = self._make_2d_rotor_dict('cont_opt') + self.assertFalse(is_adaptive_eligible(rd)) + + def test_is_adaptive_ineligible_diagonal(self): + 
"""Test diagonal brute-force scans are not eligible.""" + rd = self._make_2d_rotor_dict('brute_force_opt_diagonal') + self.assertFalse(is_adaptive_eligible(rd)) + + def test_is_adaptive_ineligible_1d(self): + """Test 1D scans are not eligible.""" + rd = self._make_2d_rotor_dict('brute_force_opt') + rd['dimensions'] = 1 + self.assertFalse(is_adaptive_eligible(rd)) + + def test_is_adaptive_enabled_default_dense(self): + """Test that adaptive is disabled by default (dense policy).""" + rd = self._make_2d_rotor_dict('brute_force_opt') + self.assertFalse(is_adaptive_enabled(rd)) + + def test_is_adaptive_enabled_with_policy(self): + """Test that adaptive is enabled when policy is 'adaptive'.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + self.assertTrue(is_adaptive_enabled(rd)) + + # -- Initialization tests -- + + def test_initialize_adaptive_scan_state(self): + """Test adaptive state initialization.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + state = rd['adaptive_scan'] + self.assertTrue(state['enabled']) + self.assertEqual(state['phase'], 'seed') + self.assertGreater(len(state['seed_points']), 0) + self.assertEqual(state['stopping_reason'], None) + self.assertEqual(len(state['completed_points']), 0) + + def test_initialize_adaptive_scan_state_idempotent(self): + """Test that initialization is idempotent.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + seeds_1 = list(rd['adaptive_scan']['seed_points']) + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + seeds_2 = rd['adaptive_scan']['seed_points'] + self.assertEqual(seeds_1, seeds_2) + + # -- Seed generation tests -- + + def test_seed_points_include_origin(self): + """Test that seed points include the current-geometry point.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') 
+ seeds = generate_adaptive_seed_points(rd, self.xyz, increment=120.0) + self.assertGreater(len(seeds), 0) + # All seeds should be 2-element tuples + for s in seeds: + self.assertEqual(len(s), 2) + + def test_seed_points_no_duplicates(self): + """Test that seed points are deduplicated.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + seeds = generate_adaptive_seed_points(rd, self.xyz, increment=10.0) + keys = [tuple(f'{a:.2f}' for a in s) for s in seeds] + self.assertEqual(len(keys), len(set(keys))) + + def test_seed_points_reasonable_count(self): + """Test that seed count is between expected bounds for 8-degree resolution.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + seeds = generate_adaptive_seed_points(rd, self.xyz, increment=8.0) + # With increment=8: coarse grid ~16x16=256, plus 1D cuts ~2x46=92, minus overlaps + # Should be well under the full grid of 46*46=2116 but substantial + self.assertGreater(len(seeds), 50) + self.assertLess(len(seeds), 500) + + # -- Bookkeeping tests -- + + def test_mark_scan_points_pending(self): + """Test marking points as pending.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + points = [[0.0, 0.0], [120.0, 120.0]] + mark_scan_points_pending(rd, points) + self.assertEqual(len(get_pending_adaptive_points(rd)), 2) + + def test_mark_scan_points_pending_no_duplicates(self): + """Test that pending doesn't get duplicates.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + points = [[0.0, 0.0]] + mark_scan_points_pending(rd, points) + mark_scan_points_pending(rd, points) + self.assertEqual(len(get_pending_adaptive_points(rd)), 1) + + def test_mark_scan_point_completed(self): + """Test completing a point moves it from pending and writes legacy.""" + rd = self._make_2d_rotor_dict('brute_force_opt', 
policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + mark_scan_points_pending(rd, [[45.0, -90.0]]) + mark_scan_point_completed(rd, [45.0, -90.0], energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + self.assertEqual(len(get_pending_adaptive_points(rd)), 0) + self.assertEqual(len(get_completed_adaptive_points(rd)), 1) + # Also in legacy directed_scan + self.assertIn(('45.00', '-90.00'), rd['directed_scan']) + entry = rd['directed_scan'][('45.00', '-90.00')] + self.assertEqual(entry['energy'], -100.0) + + def test_mark_scan_point_failed(self): + """Test failing a point removes it from pending.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + mark_scan_points_pending(rd, [[45.0, -90.0]]) + mark_scan_point_failed(rd, [45.0, -90.0]) + self.assertEqual(len(get_pending_adaptive_points(rd)), 0) + self.assertEqual(len(rd['adaptive_scan']['failed_points']), 1) + + def test_mark_scan_point_invalid(self): + """Test invalidating a point.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + mark_scan_points_pending(rd, [[45.0, -90.0]]) + mark_scan_point_invalid(rd, [45.0, -90.0]) + self.assertEqual(len(get_pending_adaptive_points(rd)), 0) + self.assertEqual(len(rd['adaptive_scan']['invalid_points']), 1) + + # -- Surrogate / model tests -- + + def test_fit_adaptive_surface_model(self): + """Test fitting a surface model from completed points.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + # Simulate some completed points + for phi0, phi1, e in [(0.0, 0.0, -100.0), (120.0, 0.0, -95.0), (0.0, 120.0, -90.0)]: + mark_scan_point_completed(rd, [phi0, phi1], energy=e, xyz={}, + is_isomorphic=True, trsh=[]) + model = fit_adaptive_surface_model(rd) + self.assertEqual(model['type'], 'idw') + 
self.assertEqual(len(model['centers']), 3) + self.assertEqual(len(model['values']), 3) + + def test_predict_surface_values(self): + """Test surface prediction at query points.""" + model = { + 'type': 'idw', + 'centers': [[0.0, 0.0], [120.0, 0.0]], + 'values': [0.0, 10.0], + 'length_scale': 30.0, + } + preds = predict_surface_values(model, [[0.0, 0.0], [120.0, 0.0], [60.0, 0.0]]) + # At exact centers, should return close to center values + self.assertAlmostEqual(preds[0], 0.0, places=1) + self.assertAlmostEqual(preds[1], 10.0, places=1) + # Midpoint should be somewhere between + self.assertGreater(preds[2], 0.0) + self.assertLess(preds[2], 10.0) + + def test_predict_surface_values_empty_model(self): + """Test prediction with empty model returns None.""" + model = {'centers': [], 'values': [], 'length_scale': 30.0} + preds = predict_surface_values(model, [[0.0, 0.0]]) + self.assertIsNone(preds[0]) + + def test_score_candidate_points(self): + """Test scoring prefers distant points.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + # Complete one point at origin + mark_scan_point_completed(rd, [0.0, 0.0], energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + fit_adaptive_surface_model(rd) + # Score: a far point vs a near point + scores = score_candidate_points(rd, [[180.0, 180.0], [1.0, 1.0]]) + self.assertGreater(scores[0], scores[1]) + + # -- Candidate generation tests -- + + def test_generate_adaptive_candidate_points(self): + """Test candidate generation excludes visited points.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + rd['original_dihedrals'] = ['0.00', '0.00'] + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + # Complete all seed points + for s in list(rd['adaptive_scan']['seed_points']): + mark_scan_point_completed(rd, s, energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + candidates = 
generate_adaptive_candidate_points(rd, increment=120.0) + # No candidate should be in completed + completed_keys = {tuple(f'{a:.2f}' for a in c) + for c in get_completed_adaptive_points(rd)} + for c in candidates: + key = tuple(f'{a:.2f}' for a in c) + self.assertNotIn(key, completed_keys) + + # -- Selection tests -- + + def test_select_next_adaptive_points_seed_phase(self): + """Test that seed phase returns seed points.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0, batch_size=5) + points = select_next_adaptive_points(rd, increment=120.0) + self.assertGreater(len(points), 0) + self.assertLessEqual(len(points), 5) + + def test_select_next_adaptive_points_no_resubmit(self): + """Test that completed points are never resubmitted.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + rd['original_dihedrals'] = ['0.00', '0.00'] + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0, batch_size=100) + # Get all seeds + batch1 = select_next_adaptive_points(rd, increment=120.0, batch_size=1000) + mark_scan_points_pending(rd, batch1) + for p in batch1: + mark_scan_point_completed(rd, p, energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + # Now get adaptive batch + batch2 = select_next_adaptive_points(rd, increment=120.0) + batch1_keys = {tuple(f'{a:.2f}' for a in p) for p in batch1} + for p in batch2: + key = tuple(f'{a:.2f}' for a in p) + self.assertNotIn(key, batch1_keys) + + # -- Stopping tests -- + + def test_stopping_max_points(self): + """Test stopping when max_points reached.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0, max_points=3) + # Simulate 3 completed points + for i, (a, b) in enumerate([(0.0, 0.0), (120.0, 0.0), (0.0, 120.0)]): + mark_scan_point_completed(rd, [a, b], energy=float(-100 + i), + xyz={}, is_isomorphic=True, trsh=[]) + reason = 
get_adaptive_stopping_reason(rd, increment=120.0) + self.assertEqual(reason, 'max_points_reached') + + def test_stopping_grid_exhausted(self): + """Test stopping when all grid points are visited.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + rd['original_dihedrals'] = ['0.00', '0.00'] + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0, + max_points=10000) + # Complete all 4x4=16 grid points + n = int(360 / 120.0) + 1 + for i in range(n): + for j in range(n): + a0 = round(((i * 120.0) + 180.0) % 360.0 - 180.0, 2) + a1 = round(((j * 120.0) + 180.0) % 360.0 - 180.0, 2) + mark_scan_point_completed(rd, [a0, a1], energy=-100.0, + xyz={}, is_isomorphic=True, trsh=[]) + self.assertFalse(should_continue_adaptive_scan(rd, increment=120.0)) + + def test_is_adaptive_scan_complete_with_pending(self): + """Test that scan is not complete when pending points exist.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0, max_points=2) + mark_scan_points_pending(rd, [[0.0, 0.0]]) + mark_scan_point_completed(rd, [120.0, 0.0], energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + # max_points=2, completed+pending=2, but pending > 0 + self.assertFalse(is_adaptive_scan_complete(rd, increment=120.0)) + + # -- Restart / serialization tests -- + + def test_adaptive_state_is_yaml_serializable(self): + """Test that adaptive state contains only YAML-safe types.""" + import json + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + # Simulate some activity + seeds = rd['adaptive_scan']['seed_points'][:3] + mark_scan_points_pending(rd, seeds) + mark_scan_point_completed(rd, seeds[0], energy=-100.0, xyz={}, + is_isomorphic=True, trsh=[]) + mark_scan_point_failed(rd, seeds[1]) + # json.dumps will raise if not serializable + json.dumps(rd['adaptive_scan']) + json.dumps(rd['sampling_policy']) + + 
def test_dense_rotor_unchanged_by_adaptive_code(self): + """Test that a dense rotor dict has no adaptive_scan key.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='dense') + self.assertNotIn('adaptive_scan', rd) + self.assertFalse(is_adaptive_enabled(rd)) + + def test_directed_scan_type_values_unchanged(self): + """Verify that no new directed_scan_type values are introduced.""" + valid_types = { + 'ess', 'brute_force_sp', 'brute_force_opt', 'cont_opt', + 'brute_force_sp_diagonal', 'brute_force_opt_diagonal', 'cont_opt_diagonal', + } + # Check eligibility function doesn't accept anything outside legacy types + for dst in valid_types: + rd = self._make_2d_rotor_dict(dst) + # Just verify it doesn't crash + is_adaptive_eligible(rd) + # A hypothetical new type should not be eligible + rd = self._make_2d_rotor_dict('adaptive_brute_force') + self.assertFalse(is_adaptive_eligible(rd)) + + def test_finalize_adds_sparse_metadata_for_adaptive(self): + """Test that finalize_directed_scan_results adds sparse metadata for adaptive scans.""" + rd = self._make_2d_rotor_dict('brute_force_opt', policy='adaptive') + initialize_adaptive_scan_state(rd, self.xyz, increment=120.0) + # Complete some points + for phi0, phi1, e in [(0.0, 0.0, -100.0), (120.0, 0.0, -95.0), (0.0, 120.0, -90.0)]: + mark_scan_point_completed(rd, [phi0, phi1], energy=e, xyz={}, + is_isomorphic=True, trsh=[]) + mark_scan_point_failed(rd, [60.0, 60.0]) + results, trshed = finalize_directed_scan_results(rd) + self.assertEqual(results['sampling_policy'], 'adaptive') + self.assertIn('adaptive_scan_summary', results) + summary = results['adaptive_scan_summary'] + self.assertEqual(summary['completed_count'], 3) + self.assertEqual(summary['failed_count'], 1) + self.assertEqual(len(summary['failed_points']), 1) + + def test_finalize_no_sparse_metadata_for_dense(self): + """Test that finalize_directed_scan_results has no sparse metadata for dense scans.""" + rd = self._make_2d_rotor_dict('brute_force_opt', 
policy='dense') + # Add some scan points directly + rd['directed_scan'] = { + ('0.00', '0.00'): {'energy': -100.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []}, + ('120.00', '0.00'): {'energy': -95.0, 'xyz': {}, 'is_isomorphic': True, 'trsh': []}, + } + results, _ = finalize_directed_scan_results(rd) + self.assertNotIn('sampling_policy', results) + self.assertNotIn('adaptive_scan_summary', results) + + +class TestSurfaceValidation(unittest.TestCase): + """ + Contains unit tests for adaptive 2D surface validation in nd_scan module. + """ + + @classmethod + def setUpClass(cls): + """Set up test fixtures.""" + cls.maxDiff = None + + def _make_validated_rotor(self, increment=120.0, energy_fn=None): + """ + Helper: create a 2D adaptive rotor with completed points on a 120-degree grid + and mock xyz data. ``energy_fn(phi0, phi1) -> float`` sets the energy. + """ + if energy_fn is None: + energy_fn = lambda a, b: abs(a) + abs(b) # smooth + rd = { + 'pivots': [[1, 2], [5, 6]], + 'top': [[0], [7]], + 'scan': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]], + 'number_of_running_jobs': 0, + 'success': None, + 'invalidation_reason': '', + 'times_dihedral_set': 0, + 'trsh_counter': 0, + 'trsh_methods': [], + 'scan_path': '', + 'directed_scan_type': 'brute_force_opt', + 'directed_scan': {}, + 'dimensions': 2, + 'original_dihedrals': ['0.00', '0.00'], + 'cont_indices': [], + 'symmetry': None, + 'max_e': None, + 'sampling_policy': 'adaptive', + 'adaptive_scan': { + 'enabled': True, + 'phase': 'complete', + 'batch_size': 10, + 'candidate_points': [], + 'pending_points': [], + 'completed_points': [], + 'failed_points': [], + 'invalid_points': [], + 'seed_points': [], + 'selected_points_history': [], + 'stopping_reason': 'max_points_reached', + 'max_points': 200, + 'min_points': 20, + 'fit_metadata': {}, + 'surface_model': {}, + }, + } + # Build a simple mock xyz (4 atoms, tetrahedron-like) + base_coords = ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), + (0.5, 0.87, 
0.0), (0.5, 0.29, 0.82)) + base_xyz = { + 'symbols': ('C', 'H', 'H', 'H'), + 'isotopes': (12, 1, 1, 1), + 'coords': base_coords, + } + n = int(360 / increment) + 1 + for i in range(n): + for j in range(n): + a0 = round((i * increment + 180.0) % 360.0 - 180.0, 2) + a1 = round((j * increment + 180.0) % 360.0 - 180.0, 2) + key = (f'{a0:.2f}', f'{a1:.2f}') + e = energy_fn(a0, a1) + rd['directed_scan'][key] = { + 'energy': e, + 'xyz': base_xyz, + 'is_isomorphic': True, + 'trsh': [], + } + rd['adaptive_scan']['completed_points'].append([a0, a1]) + return rd + + # -- Neighbor helpers -- + + def test_get_sampled_point_neighbors(self): + """Test finding neighbors of a sampled point.""" + rd = self._make_validated_rotor(increment=120.0) + neighbors = get_sampled_point_neighbors(rd, [0.0, 0.0], increment=120.0) + # 0.0 has neighbors at ±120 in each dimension + self.assertGreaterEqual(len(neighbors), 2) + + def test_iter_sampled_neighbor_edges(self): + """Test iterating unique neighbor edges.""" + rd = self._make_validated_rotor(increment=120.0) + edges = list(iter_sampled_neighbor_edges(rd, increment=120.0)) + self.assertGreater(len(edges), 0) + # Each edge should be unique + edge_keys = set() + for a, b in edges: + key = tuple(sorted([ + (f'{a[0]:.2f}', f'{a[1]:.2f}'), + (f'{b[0]:.2f}', f'{b[1]:.2f}') + ])) + self.assertNotIn(key, edge_keys) + edge_keys.add(key) + + # -- Energy jump -- + + def test_calculate_neighbor_energy_jump(self): + """Test energy jump calculation.""" + rd = self._make_validated_rotor(increment=120.0) + jump = calculate_neighbor_energy_jump(rd, [0.0, 0.0], [120.0, 0.0]) + self.assertIsNotNone(jump) + self.assertIsInstance(jump, float) + self.assertGreaterEqual(jump, 0.0) + + def test_calculate_neighbor_energy_jump_missing(self): + """Test energy jump returns None for missing point.""" + rd = self._make_validated_rotor(increment=120.0) + jump = calculate_neighbor_energy_jump(rd, [0.0, 0.0], [999.0, 999.0]) + self.assertIsNone(jump) + + # -- Geometry RMSD 
-- + + def test_calculate_neighbor_geometry_rmsd_same_xyz(self): + """Test RMSD is ~0 for identical geometries.""" + rd = self._make_validated_rotor(increment=120.0) + rmsd = calculate_neighbor_geometry_rmsd(rd, [0.0, 0.0], [120.0, 0.0]) + # Same base_xyz for all points, so RMSD should be ~0 + self.assertIsNotNone(rmsd) + self.assertAlmostEqual(rmsd, 0.0, places=4) + + def test_calculate_neighbor_geometry_rmsd_missing(self): + """Test RMSD returns None for missing geometry.""" + rd = self._make_validated_rotor(increment=120.0) + # Overwrite one point's xyz to None + rd['directed_scan'][('0.00', '0.00')]['xyz'] = None + rmsd = calculate_neighbor_geometry_rmsd(rd, [0.0, 0.0], [120.0, 0.0]) + self.assertIsNone(rmsd) + + # -- Edge classification -- + + def test_classify_continuous_edge(self): + """Test that a smooth edge is classified as continuous.""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 0.1 * (a + b)) + result = classify_neighbor_edge_continuity(rd, [0.0, 0.0], [120.0, 0.0]) + self.assertTrue(result['continuous']) + self.assertEqual(result['reasons'], []) + + def test_classify_discontinuous_edge_energy(self): + """Test that a huge energy jump flags the edge as discontinuous.""" + rd = self._make_validated_rotor(increment=120.0) + # Inject a massive energy at one point + rd['directed_scan'][('120.00', '0.00')]['energy'] = 9999.0 + result = classify_neighbor_edge_continuity(rd, [0.0, 0.0], [120.0, 0.0], + energy_threshold=10.0) + self.assertFalse(result['continuous']) + self.assertTrue(any('energy_jump' in r for r in result['reasons'])) + + # -- Periodic consistency -- + + def test_check_periodic_edge_consistency_same_geometry(self): + """Test periodic check with identical geometries is consistent.""" + rd = self._make_validated_rotor(increment=120.0) + # -120 and 120 differ by 240 on the number line but are neighbors via wrap + result = check_periodic_edge_consistency(rd, [-120.0, 0.0], [120.0, 0.0]) + 
self.assertTrue(result['consistent']) + + # -- Branch-jump detection -- + + def test_detect_branch_jump_points_smooth(self): + """Test no branch jumps on a genuinely smooth surface (constant energy).""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 5.0) + flagged = detect_branch_jump_points(rd, increment=120.0, energy_threshold=50.0) + self.assertEqual(flagged, []) + + def test_detect_branch_jump_points_spike(self): + """Test that a spike point surrounded by smooth neighbors gets flagged.""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 0.0) + # Inject a huge spike at one point + rd['directed_scan'][('0.00', '0.00')]['energy'] = 999.0 + flagged = detect_branch_jump_points(rd, increment=120.0, + energy_threshold=10.0, + min_suspicious_edges=2) + flagged_keys = {(f'{p[0]:.2f}', f'{p[1]:.2f}') for p in flagged} + self.assertIn(('0.00', '0.00'), flagged_keys) + + # -- Full validation orchestration -- + + def test_run_adaptive_surface_validation_smooth(self): + """Test validation on a smooth surface.""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 0.01 * abs(a + b)) + val = run_adaptive_surface_validation(rd, increment=120.0) + self.assertEqual(val['status'], 'complete') + self.assertGreater(val['neighbor_edges_checked'], 0) + self.assertEqual(len(val['discontinuous_edges']), 0) + self.assertEqual(len(val['branch_jump_points']), 0) + self.assertIn('Surface passed all continuity checks.', val['notes']) + + def test_run_adaptive_surface_validation_with_spike(self): + """Test validation detects a discontinuity.""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 0.0) + rd['directed_scan'][('0.00', '0.00')]['energy'] = 999.0 + val = run_adaptive_surface_validation(rd, increment=120.0) + self.assertEqual(val['status'], 'complete') + self.assertGreater(len(val['discontinuous_edges']), 0) + + def test_run_adaptive_surface_validation_empty(self): + """Test 
validation with no edges.""" + rd = self._make_validated_rotor(increment=120.0) + # Clear all but one point + keys = list(rd['directed_scan'].keys()) + for k in keys[1:]: + del rd['directed_scan'][k] + val = run_adaptive_surface_validation(rd, increment=120.0) + self.assertEqual(val['status'], 'no_edges') + + # -- Serializability -- + + def test_validation_state_yaml_safe(self): + """Test that validation state is JSON/YAML-serializable.""" + import json + rd = self._make_validated_rotor(increment=120.0) + val = run_adaptive_surface_validation(rd, increment=120.0) + json.dumps(val) # should not raise + + # -- Integration -- + + def test_update_adaptive_validation_state(self): + """Test that update writes validation into rotor state.""" + rd = self._make_validated_rotor(increment=120.0) + update_adaptive_validation_state(rd, increment=120.0) + self.assertIn('validation', rd['adaptive_scan']) + self.assertEqual(rd['adaptive_scan']['validation']['status'], 'complete') + + def test_update_skips_dense(self): + """Test that update does nothing for dense scans.""" + rd = self._make_validated_rotor(increment=120.0) + rd['sampling_policy'] = 'dense' + update_adaptive_validation_state(rd, increment=120.0) + self.assertNotIn('validation', rd['adaptive_scan']) + + def test_finalize_includes_validation_summary(self): + """Test that finalize includes validation_summary for adaptive scans.""" + rd = self._make_validated_rotor(increment=120.0, + energy_fn=lambda a, b: 0.0) + results, _ = finalize_directed_scan_results(rd, increment=120.0) + self.assertIn('validation_summary', results) + self.assertEqual(results['validation_summary']['status'], 'complete') + + def test_finalize_no_validation_for_dense(self): + """Test that finalize omits validation_summary for dense scans.""" + rd = self._make_validated_rotor(increment=120.0) + rd['sampling_policy'] = 'dense' + results, _ = finalize_directed_scan_results(rd, increment=120.0) + self.assertNotIn('validation_summary', results) + + def 
test_edge_classification_output_shape(self): + """Test that classify_neighbor_edge_continuity returns all expected keys.""" + rd = self._make_validated_rotor(increment=120.0) + result = classify_neighbor_edge_continuity(rd, [0.0, 0.0], [120.0, 0.0]) + self.assertIn('continuous', result) + self.assertIn('energy_jump', result) + self.assertIn('geometry_rmsd', result) + self.assertIn('reasons', result) + self.assertIsInstance(result['reasons'], list) + + +class TestCouplingAndClassification(unittest.TestCase): + """ + Tests for coupling metrics, surface quality, and ND rotor classification. + Uses synthetic 2D surfaces with known separable/coupled structure. + """ + + def _make_rotor_with_surface(self, energy_fn, increment=120.0, n_failed=0, n_invalid=0): + """ + Helper: create a fully populated adaptive 2D rotor dict with a given energy function. + energy_fn(phi0_deg, phi1_deg) -> energy in kJ/mol (already normalized, min ~0). + """ + rd = { + 'pivots': [[1, 2], [5, 6]], + 'top': [[0], [7]], + 'scan': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'torsion': [[0, 1, 2, 3], [4, 5, 6, 7]], + 'number_of_running_jobs': 0, + 'success': None, + 'invalidation_reason': '', + 'times_dihedral_set': 0, + 'trsh_counter': 0, + 'trsh_methods': [], + 'scan_path': '', + 'directed_scan_type': 'brute_force_opt', + 'directed_scan': {}, + 'dimensions': 2, + 'original_dihedrals': ['0.00', '0.00'], + 'cont_indices': [], + 'symmetry': None, + 'max_e': None, + 'sampling_policy': 'adaptive', + 'adaptive_scan': { + 'enabled': True, + 'phase': 'complete', + 'batch_size': 10, + 'candidate_points': [], + 'pending_points': [], + 'completed_points': [], + 'failed_points': [], + 'invalid_points': [], + 'seed_points': [], + 'selected_points_history': [], + 'stopping_reason': 'max_points_reached', + 'max_points': 200, + 'min_points': 20, + 'fit_metadata': {}, + 'surface_model': {}, + }, + } + n = int(360 / increment) + 1 + count = 0 + for i in range(n): + for j in range(n): + a0 = round((i * increment + 180.0) % 
360.0 - 180.0, 2) + a1 = round((j * increment + 180.0) % 360.0 - 180.0, 2) + if count < n_failed: + rd['adaptive_scan']['failed_points'].append([a0, a1]) + count += 1 + continue + if count < n_failed + n_invalid: + rd['adaptive_scan']['invalid_points'].append([a0, a1]) + count += 1 + continue + key = (f'{a0:.2f}', f'{a1:.2f}') + e = energy_fn(a0, a1) + rd['directed_scan'][key] = { + 'energy': e, 'xyz': {}, 'is_isomorphic': True, 'trsh': [], + } + rd['adaptive_scan']['completed_points'].append([a0, a1]) + count += 1 + return rd + + # --- Data extraction --- + + def test_extract_surface_arrays(self): + """Test surface array extraction.""" + rd = self._make_rotor_with_surface(lambda a, b: abs(a) + abs(b), increment=120.0) + data = extract_adaptive_2d_surface_arrays(rd) + self.assertGreater(data['n_points'], 0) + self.assertEqual(len(data['phi0']), data['n_points']) + self.assertEqual(len(data['energy']), data['n_points']) + + # --- Separable fit --- + + def test_separable_fit_on_separable_surface(self): + """A purely separable surface should have near-zero fit error.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))) + 5.0 * (1 - math.cos(math.radians(b))), + increment=60.0) + data = extract_adaptive_2d_surface_arrays(rd) + sep_fit = fit_separable_surface_proxy(data) + error = calculate_separable_fit_error(data, sep_fit) + self.assertLess(error, 0.05, 'Separable surface should have small fit error') + + def test_separable_fit_on_coupled_surface(self): + """A coupled surface should have larger fit error.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * math.cos(math.radians(a - b)), + increment=60.0) + data = extract_adaptive_2d_surface_arrays(rd) + sep_fit = fit_separable_surface_proxy(data) + error = calculate_separable_fit_error(data, sep_fit) + self.assertGreater(error, 0.05, 'Coupled surface should have larger fit error') + + # --- Nonseparability score --- + + def test_nonseparability_separable(self): + 
"""Separable surface should have low nonseparability score.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))) + 5.0 * (1 - math.cos(math.radians(b))), + increment=60.0) + data = extract_adaptive_2d_surface_arrays(rd) + sep_fit = fit_separable_surface_proxy(data) + score = calculate_nonseparability_score(data, sep_fit) + self.assertLess(score, COUPLING_NONSEP_THRESHOLD) + + def test_nonseparability_coupled(self): + """Coupled surface should have high nonseparability score.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * math.cos(math.radians(a - b)), + increment=60.0) + data = extract_adaptive_2d_surface_arrays(rd) + sep_fit = fit_separable_surface_proxy(data) + score = calculate_nonseparability_score(data, sep_fit) + self.assertGreater(score, COUPLING_NONSEP_THRESHOLD) + + # --- Cross-term strength --- + + def test_cross_term_separable(self): + """Separable surface should have low cross-term strength.""" + rd = self._make_rotor_with_surface( + lambda a, b: abs(a) / 180.0 * 10.0 + abs(b) / 180.0 * 5.0, + increment=60.0) + data = extract_adaptive_2d_surface_arrays(rd) + sep_fit = fit_separable_surface_proxy(data) + ct = calculate_cross_term_strength(data, sep_fit) + self.assertLess(ct, 0.15) + + # --- Low-energy-path coupling --- + + def test_low_energy_path_separable(self): + """For a separable surface, low-energy path should show low correlation.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))) + 5.0 * (1 - math.cos(math.radians(b))), + increment=30.0) + data = extract_adaptive_2d_surface_arrays(rd) + coupling = calculate_low_energy_path_coupling(data) + # For a truly separable surface the low-energy ridge is at phi0~0 regardless of phi1 + # so sin/cos correlation should be low + self.assertLess(coupling, 0.7) + + # --- Complete coupling metrics --- + + def test_compute_coupling_metrics_separable(self): + """Coupling metrics for a separable surface.""" + rd = 
self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))) + 5.0 * (1 - math.cos(math.radians(b))), + increment=60.0) + metrics = compute_coupling_metrics(rd) + self.assertEqual(metrics['status'], 'complete') + self.assertLess(metrics['nonseparability_score'], COUPLING_NONSEP_THRESHOLD) + + def test_compute_coupling_metrics_insufficient(self): + """Coupling metrics with too few points.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + # Only 16 points but some might be enough. Let's clear most. + keys_to_remove = list(rd['directed_scan'].keys())[5:] + for k in keys_to_remove: + del rd['directed_scan'][k] + rd['adaptive_scan']['completed_points'] = rd['adaptive_scan']['completed_points'][:5] + metrics = compute_coupling_metrics(rd) + self.assertEqual(metrics['status'], 'insufficient_data') + + # --- Surface quality --- + + def test_surface_quality_good(self): + """Quality metrics for a clean scan.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + # Run validation first (needed for warning fraction) + update_adaptive_validation_state(rd, increment=120.0) + metrics = compute_surface_quality_metrics(rd, increment=120.0) + self.assertEqual(metrics['status'], 'complete') + self.assertGreater(metrics['quality_score'], 0.5) + self.assertAlmostEqual(metrics['failed_fraction'], 0.0) + + def test_surface_quality_many_failures(self): + """Quality metrics with many failed points.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0, n_failed=6) + metrics = compute_surface_quality_metrics(rd, increment=120.0) + self.assertGreater(metrics['failed_fraction'], 0.1) + + def test_coverage_fraction(self): + """Test coverage calculation.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + cov = calculate_coverage_fraction(rd, increment=120.0) + # Grid formula gives (360/120+1)^2 = 16, but -180 and 180 share a key + # after normalization, so 9 unique keys out of 
16 grid points. + self.assertGreater(cov, 0.5) + self.assertLessEqual(cov, 1.0) + + # --- ND classification --- + + def test_classify_separable(self): + """Clean separable surface should be classified as separable.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))) + 5.0 * (1 - math.cos(math.radians(b))), + increment=60.0) + update_adaptive_validation_state(rd, increment=60.0) + update_nd_classification(rd, increment=60.0) + cls = rd['adaptive_scan']['nd_classification'] + self.assertEqual(cls['classification'], 'separable') + self.assertEqual(cls['recommended_action'], 'treat_as_separable_1d_like') + + def test_classify_coupled(self): + """Clean coupled surface should be classified as coupled.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * math.cos(math.radians(a - b)), + increment=60.0) + update_adaptive_validation_state(rd, increment=60.0) + update_nd_classification(rd, increment=60.0) + cls = rd['adaptive_scan']['nd_classification'] + self.assertEqual(cls['classification'], 'coupled') + self.assertEqual(cls['recommended_action'], 'retain_as_coupled_2d_surface') + + def test_classify_unreliable_many_failures(self): + """Surface with many failures should be classified as unreliable.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=60.0, n_failed=15) + update_adaptive_validation_state(rd, increment=60.0) + update_nd_classification(rd, increment=60.0) + cls = rd['adaptive_scan']['nd_classification'] + self.assertEqual(cls['classification'], 'unreliable') + self.assertEqual(cls['recommended_action'], 'fallback_due_to_surface_quality') + + def test_classify_unreliable_insufficient_data(self): + """Very few points should be classified as unreliable.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + # Remove most points + keys = list(rd['directed_scan'].keys())[3:] + for k in keys: + del rd['directed_scan'][k] + rd['adaptive_scan']['completed_points'] = 
rd['adaptive_scan']['completed_points'][:3] + update_adaptive_validation_state(rd, increment=120.0) + update_nd_classification(rd, increment=120.0) + cls = rd['adaptive_scan']['nd_classification'] + self.assertEqual(cls['classification'], 'unreliable') + + # --- Dense unchanged --- + + def test_classify_skips_dense(self): + """update_nd_classification should do nothing for dense scans.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + rd['sampling_policy'] = 'dense' + update_nd_classification(rd, increment=120.0) + self.assertNotIn('nd_classification', rd.get('adaptive_scan', {})) + + # --- Finalization integration --- + + def test_finalize_includes_classification_summary(self): + """Finalization should include classification summary for adaptive scans.""" + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * (1 - math.cos(math.radians(a))), + increment=60.0) + results, _ = finalize_directed_scan_results(rd, increment=60.0) + self.assertIn('classification_summary', results) + self.assertIn('coupling_summary', results) + self.assertIn('surface_quality_summary', results) + self.assertIsNotNone(results['classification_summary']['classification']) + + def test_finalize_no_classification_for_dense(self): + """Finalization omits classification for dense scans.""" + rd = self._make_rotor_with_surface(lambda a, b: 0.0, increment=120.0) + rd['sampling_policy'] = 'dense' + results, _ = finalize_directed_scan_results(rd, increment=120.0) + self.assertNotIn('classification_summary', results) + self.assertNotIn('coupling_summary', results) + + # --- Serializability --- + + def test_classification_metadata_yaml_safe(self): + """All new metadata should be JSON/YAML-serializable.""" + import json + rd = self._make_rotor_with_surface( + lambda a, b: 10.0 * math.cos(math.radians(a - b)), + increment=60.0) + update_adaptive_validation_state(rd, increment=60.0) + update_nd_classification(rd, increment=60.0) + state = rd['adaptive_scan'] + 
json.dumps(state.get('coupling_metrics', {})) + json.dumps(state.get('surface_quality', {})) + json.dumps(state.get('nd_classification', {})) + + +if __name__ == '__main__': + unittest.main() From 484ddba8a0ad2df54b4c74b0e796fc4c27c60634 Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Thu, 2 Apr 2026 00:32:58 +0300 Subject: [PATCH 2/4] Added directed scan to xTB job adapter --- arc/job/adapters/xtb_adapter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arc/job/adapters/xtb_adapter.py b/arc/job/adapters/xtb_adapter.py index 7328a2ce4c..91ac340160 100644 --- a/arc/job/adapters/xtb_adapter.py +++ b/arc/job/adapters/xtb_adapter.py @@ -218,7 +218,9 @@ def write_input_file(self) -> None: directives, block = '', '' uhf = self.species[0].number_of_radicals or self.multiplicity - 1 - if self.job_type in ['opt', 'conf_opt', 'scan']: + if self.job_type in ['opt', 'conf_opt', 'scan'] \ + or (self.job_type == 'directed_scan' and self.directed_scan_type is not None + and 'opt' in self.directed_scan_type): directives += ' --opt' directives += self.add_accuracy() if self.constraints and self.job_type != 'scan': From ff986a5ec237738884fb3f097d25eb526725540d Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Thu, 2 Apr 2026 00:33:17 +0300 Subject: [PATCH 3/4] Added ND scan functionalities to plotter --- arc/plotter.py | 213 +++++++++++++++++++++++++++++++++++++++++++- arc/plotter_test.py | 129 +++++++++++++++++++++++++++ 2 files changed, 339 insertions(+), 3 deletions(-) diff --git a/arc/plotter.py b/arc/plotter.py index d0f6938e84..8934bc0dc3 100644 --- a/arc/plotter.py +++ b/arc/plotter.py @@ -1293,6 +1293,12 @@ def plot_2d_rotor_scan(results: dict, if len(results['scans']) != 2: raise InputError(f'results must represent a 2D rotor, got {len(results["scans"])}D') + # Dispatch to sparse plotting for adaptive scans + if is_sparse_2d_scan(results): + _plot_2d_rotor_scan_sparse(results, path=path, label=label, cmap=cmap, + resolution=resolution, 
original_dihedrals=original_dihedrals) + return + results['directed_scan'] = clean_scan_results(results['directed_scan']) # phis0 and phis1 correspond to columns and rows in energies, respectively @@ -1357,9 +1363,9 @@ def plot_2d_rotor_scan(results: dict, label = ' for ' + label if label else '' plt.title(f'2D scan energies (kJ/mol){label}') min_x = min_y = -180 - plt.xlim = (min_x, min_x + 360) + plt.gca().set_xlim(min_x, min_x + 360) plt.xticks(np.arange(min_x, min_x + 361, step=60)) - plt.ylim = (min_y, min_y + 360) + plt.gca().set_ylim(min_y, min_y + 360) plt.yticks(np.arange(min_y, min_y + 361, step=60)) if mark_lowest_conformations: @@ -1379,6 +1385,207 @@ def plot_2d_rotor_scan(results: dict, plt.close(fig=fig) +def is_sparse_2d_scan(results: dict) -> bool: + """ + Detect whether a 2D scan results dict represents a sparse/adaptive scan. + + A scan is considered sparse if the results contain + ``sampling_policy == 'adaptive'``. + + Args: + results (dict): The results dictionary from a 2D directed scan. + + Returns: + bool: ``True`` if the scan is sparse/adaptive. + """ + return results.get('sampling_policy') == 'adaptive' + + +def extract_sparse_2d_points(results: dict) -> dict: + """ + Extract sampled point coordinates and energies from a sparse 2D scan result. + + Args: + results (dict): The results dictionary from a 2D directed scan. + + Returns: + dict: A dictionary with keys ``'x'``, ``'y'``, ``'energy'`` (lists of floats for + completed points with non-None energy), plus ``'failed_points'`` and + ``'invalid_points'`` (lists of ``[x, y]`` pairs). 
+ """ + xs, ys, energies = [], [], [] + for key, entry in results.get('directed_scan', {}).items(): + e = entry.get('energy') + if e is not None: + xs.append(float(key[0])) + ys.append(float(key[1])) + energies.append(float(e)) + summary = results.get('adaptive_scan_summary', {}) + return { + 'x': xs, + 'y': ys, + 'energy': energies, + 'failed_points': summary.get('failed_points', []), + 'invalid_points': summary.get('invalid_points', []), + } + + +def interpolate_sparse_2d_scan(points_x: list, + points_y: list, + energies: list, + grid_resolution: float = 2.0, + ) -> tuple: + """ + Interpolate sparse 2D scan data onto a dense grid for contour plotting. + + Uses ``scipy.interpolate.griddata`` with periodic boundary augmentation + to reduce artifacts at the -180/+180 wrap boundary. + + Args: + points_x (list): Sampled dihedral angles for dimension 0 (degrees). + points_y (list): Sampled dihedral angles for dimension 1 (degrees). + energies (list): Energy values at sampled points (kJ/mol). + grid_resolution (float): Spacing of the dense output grid in degrees. + + Returns: + tuple: ``(grid_x, grid_y, grid_energies)`` where each is a 2D numpy array + suitable for ``plt.contourf``. 
+ """ + from scipy.interpolate import griddata + + px = np.array(points_x, dtype=np.float64) + py = np.array(points_y, dtype=np.float64) + pe = np.array(energies, dtype=np.float64) + + # Augment with periodic image points for wrap-around + aug_x, aug_y, aug_e = list(px), list(py), list(pe) + for dx in (-360.0, 0.0, 360.0): + for dy in (-360.0, 0.0, 360.0): + if dx == 0.0 and dy == 0.0: + continue + aug_x.extend(px + dx) + aug_y.extend(py + dy) + aug_e.extend(pe) + aug_x = np.array(aug_x) + aug_y = np.array(aug_y) + aug_e = np.array(aug_e) + + # Dense grid from -180 to 180 + n_pts = int(360.0 / grid_resolution) + 1 + gx = np.linspace(-180.0, 180.0, n_pts) + gy = np.linspace(-180.0, 180.0, n_pts) + grid_x, grid_y = np.meshgrid(gx, gy, indexing='ij') + + # Interpolate: try cubic, fall back to linear, then nearest + pts = np.column_stack([aug_x, aug_y]) + grid_e = None + for method in ('cubic', 'linear'): + try: + grid_e = griddata(pts, aug_e, (grid_x, grid_y), method=method) + if not np.all(np.isnan(grid_e)): + break + except (ValueError, Exception): + grid_e = None + if grid_e is None or np.all(np.isnan(grid_e)): + grid_e = griddata(pts, aug_e, (grid_x, grid_y), method='nearest') + # Fill any remaining NaN with nearest-neighbor + mask = np.isnan(grid_e) + if mask.any(): + grid_nearest = griddata(pts, aug_e, (grid_x, grid_y), method='nearest') + grid_e[mask] = grid_nearest[mask] + + return grid_x, grid_y, grid_e + + +def _plot_2d_rotor_scan_sparse(results: dict, + path: Optional[str] = None, + label: str = '', + cmap: str = 'Blues', + resolution: int = 90, + original_dihedrals: Optional[List[float]] = None, + ): + """ + Plot a sparse/adaptive 2D rotor scan using interpolation for contours + and overlaying sampled, failed, and invalid points. + + This is called internally by :func:`plot_2d_rotor_scan` when the results + are detected as sparse. + + Args: + results (dict): The results dictionary from a 2D directed scan. 
+ path (str, optional): Folder path to save the plot image. + label (str, optional): Species label. + cmap (str, optional): Matplotlib colormap name. + resolution (int, optional): Image DPI. + original_dihedrals (list, optional): Original dihedral angles for marker. + """ + data = extract_sparse_2d_points(results) + xs, ys, energies = data['x'], data['y'], data['energy'] + + if len(xs) < 3: + logger.warning(f'Not enough sparse points to plot 2D scan ({len(xs)} points)') + return + + # Normalize energies to min = 0 + e_min = min(energies) + energies_norm = [e - e_min for e in energies] + + # Interpolate to dense grid + grid_x, grid_y, grid_e = interpolate_sparse_2d_scan(xs, ys, energies_norm, grid_resolution=2.0) + + fig = plt.figure(num=None, figsize=(12, 8), dpi=resolution, facecolor='w', edgecolor='k') + + plt.contourf(grid_x, grid_y, grid_e, 20, cmap=cmap) + plt.colorbar() + contours = plt.contour(grid_x, grid_y, grid_e, 4, colors='black') + plt.clabel(contours, inline=True, fontsize=8) + + # Overlay sampled points + plt.scatter(xs, ys, c='black', s=12, zorder=5, label='sampled') + + # Overlay failed points + failed = data.get('failed_points', []) + if failed: + fx = [p[0] for p in failed] + fy = [p[1] for p in failed] + plt.scatter(fx, fy, c='red', marker='x', s=40, zorder=6, label='failed') + + # Overlay invalid points + invalid = data.get('invalid_points', []) + if invalid: + ix = [p[0] for p in invalid] + iy = [p[1] for p in invalid] + plt.scatter(ix, iy, edgecolors='orange', marker='s', facecolors='none', + s=40, zorder=6, label='invalid') + + # Mark original dihedral + if original_dihedrals is not None and len(original_dihedrals) >= 2: + plt.plot(original_dihedrals[0], original_dihedrals[1], color='r', + marker='.', markersize=15, linewidth=0, label='original') + + plt.xlabel(f'Dihedral 1 for {results["scans"][0]} (degrees)') + plt.ylabel(f'Dihedral 2 for {results["scans"][1]} (degrees)') + label_str = ' for ' + label if label else '' + summary = 
results.get('adaptive_scan_summary', {}) + n_pts = summary.get('completed_count', len(xs)) + plt.title(f'2D scan energies (kJ/mol){label_str} [adaptive, {n_pts} pts]') + plt.gca().set_xlim(-180, 180) + plt.xticks(np.arange(-180, 181, step=60)) + plt.gca().set_ylim(-180, 180) + plt.yticks(np.arange(-180, 181, step=60)) + + plt.legend(loc='upper right', fontsize=8) + + if path is not None: + fig_name = f'{results["directed_scan_type"]}_{results["scans"]}_adaptive.png' + fig_path = os.path.join(path, fig_name) + plt.savefig(fig_path, dpi=resolution, facecolor='w', edgecolor='w', orientation='portrait', + format='png', transparent=False, bbox_inches=None, pad_inches=0.1, metadata=None) + + plt.show() + plt.close(fig=fig) + + def plot_2d_scan_bond_dihedral(results: dict, path: Optional[str] = None, label: str = '', @@ -1486,7 +1693,7 @@ def plot_2d_scan_bond_dihedral(results: dict, label = ' for ' + label if label else '' plt.title(f'2D scan energies (kJ/mol){label}') min_x = -180 - plt.xlim = (min_x, min_x + 360) + plt.gca().set_xlim(min_x, min_x + 360) plt.xticks(np.arange(min_x, min_x + 361, step=60)) if original_dihedrals is not None: diff --git a/arc/plotter_test.py b/arc/plotter_test.py index ba6984dae4..617077c54c 100644 --- a/arc/plotter_test.py +++ b/arc/plotter_test.py @@ -9,6 +9,8 @@ import shutil import unittest +import numpy as np + import arc.plotter as plotter from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, safe_copy_file from arc.species.converter import str_to_xyz @@ -236,5 +238,132 @@ def tearDownClass(cls): os.remove(file_path) +class TestSparse2DPlotting(unittest.TestCase): + """ + Contains unit tests for sparse 2D rotor scan plotting helpers. 
+ """ + + def _make_dense_results(self): + """Helper: build a small dense 2D result dict (4x4 grid, increment=120).""" + directed_scan = {} + for a0 in [0.0, 120.0, -120.0, 0.0]: + for a1 in [0.0, 120.0, -120.0, 0.0]: + key = (f'{a0:.2f}', f'{a1:.2f}') + if key not in directed_scan: + directed_scan[key] = { + 'energy': float(abs(a0) + abs(a1)) / 10.0, + 'xyz': {}, + 'is_isomorphic': True, + 'trsh': [], + } + return { + 'directed_scan_type': 'brute_force_opt', + 'scans': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'directed_scan': directed_scan, + } + + def _make_sparse_results(self, n_points=20): + """Helper: build a sparse adaptive 2D result dict.""" + import random + random.seed(42) + directed_scan = {} + for _ in range(n_points): + a0 = round(random.uniform(-180, 180), 2) + a1 = round(random.uniform(-180, 180), 2) + key = (f'{a0:.2f}', f'{a1:.2f}') + directed_scan[key] = { + 'energy': float(abs(a0) + abs(a1)) / 10.0, + 'xyz': {}, + 'is_isomorphic': True, + 'trsh': [], + } + return { + 'directed_scan_type': 'brute_force_opt', + 'scans': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'directed_scan': directed_scan, + 'sampling_policy': 'adaptive', + 'adaptive_scan_summary': { + 'completed_count': n_points, + 'failed_count': 2, + 'invalid_count': 1, + 'stopping_reason': 'max_points_reached', + 'failed_points': [[45.0, -90.0], [120.0, 60.0]], + 'invalid_points': [[-30.0, 150.0]], + }, + } + + def test_is_sparse_2d_scan_dense(self): + """Test that dense results are not detected as sparse.""" + results = self._make_dense_results() + self.assertFalse(plotter.is_sparse_2d_scan(results)) + + def test_is_sparse_2d_scan_adaptive(self): + """Test that adaptive results are detected as sparse.""" + results = self._make_sparse_results() + self.assertTrue(plotter.is_sparse_2d_scan(results)) + + def test_extract_sparse_2d_points(self): + """Test extraction of sparse point data.""" + results = self._make_sparse_results(15) + data = plotter.extract_sparse_2d_points(results) + 
self.assertEqual(len(data['x']), len(data['y'])) + self.assertEqual(len(data['x']), len(data['energy'])) + self.assertGreater(len(data['x']), 0) + self.assertEqual(len(data['failed_points']), 2) + self.assertEqual(len(data['invalid_points']), 1) + + def test_extract_sparse_2d_points_dense(self): + """Test extraction from dense results (no adaptive summary).""" + results = self._make_dense_results() + data = plotter.extract_sparse_2d_points(results) + self.assertGreater(len(data['x']), 0) + self.assertEqual(data['failed_points'], []) + self.assertEqual(data['invalid_points'], []) + + def test_interpolate_sparse_2d_scan(self): + """Test interpolation produces a dense grid.""" + xs = [0.0, 90.0, -90.0, 180.0, -180.0, 45.0, -45.0] + ys = [0.0, 90.0, -90.0, 180.0, -180.0, 45.0, -45.0] + es = [0.0, 5.0, 5.0, 10.0, 10.0, 3.0, 3.0] + gx, gy, ge = plotter.interpolate_sparse_2d_scan(xs, ys, es, grid_resolution=10.0) + # Check shapes match + self.assertEqual(gx.shape, gy.shape) + self.assertEqual(gx.shape, ge.shape) + n = int(360.0 / 10.0) + 1 + self.assertEqual(gx.shape, (n, n)) + # No NaN values + self.assertFalse(np.any(np.isnan(ge))) + + def test_plot_sparse_2d_no_crash(self): + """Test that plotting a sparse scan doesn't crash.""" + import tempfile + results = self._make_sparse_results(30) + with tempfile.TemporaryDirectory() as tmpdir: + # Should not raise + plotter.plot_2d_rotor_scan(results, path=tmpdir) + # Check that a file was created + files = os.listdir(tmpdir) + self.assertTrue(any('adaptive' in f for f in files), + f'Expected adaptive plot file, got: {files}') + + def test_plot_dense_2d_unchanged(self): + """Test that plotting a dense scan still works through the legacy path.""" + # This exercises the existing code path; if it crashes, the dense path is broken + results = self._make_dense_results() + # Don't save to disk, just ensure no crash + try: + plotter.plot_2d_rotor_scan(results, path=None) + except (ValueError, KeyError): + # Dense path might fail on 
this small test grid due to missing points, + # but it should NOT dispatch to sparse path + self.assertFalse(plotter.is_sparse_2d_scan(results)) + + def test_plot_sparse_too_few_points_no_crash(self): + """Test that sparse plotting with < 3 points doesn't crash.""" + results = self._make_sparse_results(2) + # Should not raise, just warn + plotter.plot_2d_rotor_scan(results, path=None) + + if __name__ == '__main__': unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) From e8f8a8f72a3e755ef1bb333a58f314a38fecfcd6 Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Thu, 2 Apr 2026 00:33:29 +0300 Subject: [PATCH 4/4] Adaptations to Scheduler --- arc/scheduler.py | 344 +++++++++++++++++++++++++----------------- arc/scheduler_test.py | 245 +++++++++++++++++++++++++++++- 2 files changed, 448 insertions(+), 141 deletions(-) diff --git a/arc/scheduler.py b/arc/scheduler.py index 21f5a2a7a7..b755ae6760 100644 --- a/arc/scheduler.py +++ b/arc/scheduler.py @@ -4,7 +4,6 @@ """ import datetime -import itertools import os import pprint import shutil @@ -55,6 +54,28 @@ xyz_to_coords_list, xyz_to_str, ) +from arc.species.nd_scan import (decrement_running_jobs, + finalize_directed_scan_results, + get_continuous_scan_dihedrals, + get_pending_adaptive_points, + get_rotor_dict_by_pivots, + get_torsion_dihedral_grid, + increment_continuous_scan_indices, + initialize_adaptive_scan_state, + initialize_continuous_scan_state, + is_adaptive_enabled, + is_adaptive_scan_complete, + is_continuous_scan_complete, + iter_brute_force_scan_points, + mark_scan_point_completed, + mark_scan_point_failed, + mark_scan_point_invalid, + mark_scan_points_pending, + record_directed_scan_point, + select_next_adaptive_points, + should_continue_adaptive_scan, + validate_scan_resolution, + ) from arc.species.perceive import perceive_molecule_from_xyz from arc.species.vectors import get_angle, calculate_dihedral_angle @@ -659,13 +680,26 @@ def schedule_jobs(self): 
self.spawn_directed_scan_jobs(label=label, rotor_index=job.rotor_index, xyz=xyz) if 'brute_force' in job.directed_scan_type: # Just terminated a brute_force directed scan job. - # Are there additional jobs of the same type currently running for this species? - self.species_dict[label].rotors_dict[job.rotor_index]['number_of_running_jobs'] -= 1 - if not self.species_dict[label].rotors_dict[job.rotor_index]['number_of_running_jobs']: - # All brute force scan jobs for these pivots terminated. - logger.info(f'\nAll brute force directed scan jobs for species {label} between ' - f'pivots {job.pivots} successfully terminated.\n') - self.process_directed_scans(label, pivots=job.pivots) + rotor_dict = self.species_dict[label].rotors_dict[job.rotor_index] + all_done = decrement_running_jobs(rotor_dict) + if all_done: + if is_adaptive_enabled(rotor_dict): + # Adaptive: check if more batches are needed + if is_adaptive_scan_complete(rotor_dict, rotor_scan_resolution): + logger.info(f'\nAdaptive scan for species {label} between ' + f'pivots {job.pivots} complete: ' + f'{rotor_dict["adaptive_scan"]["stopping_reason"]}.\n') + self.process_directed_scans(label, pivots=job.pivots) + else: + # Spawn next adaptive batch + self.spawn_directed_scan_jobs( + label=label, rotor_index=job.rotor_index) + else: + # Dense: all brute force scan jobs for these pivots terminated. + logger.info(f'\nAll brute force directed scan jobs for species ' + f'{label} between pivots {job.pivots} ' + f'successfully terminated.\n') + self.process_directed_scans(label, pivots=job.pivots) shutil.rmtree(job.local_path, ignore_errors=True) self.timer = False break @@ -1627,8 +1661,7 @@ def spawn_directed_scan_jobs(self, SchedulerError: If the rotor scan resolution as defined in settings.py is illegal. 
""" increment = rotor_scan_resolution - if divmod(360, increment)[1]: - raise SchedulerError(f'The directed scan got an illegal scan resolution of {increment}') + validate_scan_resolution(increment) torsions = self.species_dict[label].rotors_dict[rotor_index]['torsion'] directed_scan_type = self.species_dict[label].rotors_dict[rotor_index]['directed_scan_type'] xyz = xyz or self.species_dict[label].get_xyz(generate=True) @@ -1648,20 +1681,16 @@ def spawn_directed_scan_jobs(self, ) elif 'brute' in directed_scan_type: - # spawn jobs all at once - dihedrals = dict() - - for torsion in torsions: - original_dihedral = get_angle_in_180_range(calculate_dihedral_angle(coords=xyz['coords'], - torsion=torsion, - index=0)) - dihedrals[tuple(torsion)] = [get_angle_in_180_range(original_dihedral + i * increment) for i in - range(int(360 / increment) + 1)] - modified_xyz = xyz - if 'diagonal' not in directed_scan_type: - # increment dihedrals one by one (resulting in an ND scan) - all_dihedral_combinations = list(itertools.product(*[dihedrals[tuple(torsion)] for torsion in torsions])) - for dihedral_tuple in all_dihedral_combinations: + rotor_dict = self.species_dict[label].rotors_dict[rotor_index] + if is_adaptive_enabled(rotor_dict): + # Adaptive sparse brute-force submission + self._spawn_adaptive_brute_force_jobs(label, rotor_index, xyz, increment) + else: + # Dense full-grid brute-force submission (legacy behavior) + dihedrals = get_torsion_dihedral_grid(xyz=xyz, torsions=torsions, increment=increment) + is_diagonal = 'diagonal' in directed_scan_type + modified_xyz = xyz + for dihedral_tuple in iter_brute_force_scan_points(dihedrals, torsions, diagonal=is_diagonal): for torsion, dihedral in zip(torsions, dihedral_tuple): self.species_dict[label].set_dihedral(scan=torsion, index=0, @@ -1669,7 +1698,7 @@ def spawn_directed_scan_jobs(self, count=False, xyz=modified_xyz) modified_xyz = self.species_dict[label].initial_xyz - 
self.species_dict[label].rotors_dict[rotor_index]['number_of_running_jobs'] += 1 + rotor_dict['number_of_running_jobs'] += 1 self.run_job(label=label, xyz=modified_xyz, level_of_theory=self.scan_level, @@ -1679,44 +1708,14 @@ def spawn_directed_scan_jobs(self, dihedrals=list(dihedral_tuple), rotor_index=rotor_index, ) - else: - # increment all dihedrals at once (resulting in a unique 1D scan along several changing dimensions) - for i in range(len(dihedrals[tuple(torsions[0])])): - for torsion in torsions: - dihedral = dihedrals[tuple(torsion)][i] - self.species_dict[label].set_dihedral(scan=torsion, - index=0, - deg_abs=dihedral, - count=False, - xyz=modified_xyz) - modified_xyz = self.species_dict[label].initial_xyz - dihedrals = [dihedrals[tuple(torsion)][i] for torsion in torsions] - self.species_dict[label].rotors_dict[rotor_index]['number_of_running_jobs'] += 1 - self.run_job(label=label, - xyz=modified_xyz, - level_of_theory=self.scan_level, - job_type='directed_scan', - directed_scan_type=directed_scan_type, - torsions=torsions, - dihedrals=dihedrals, - rotor_index=rotor_index, - ) elif 'cont' in directed_scan_type: # spawn jobs one by one - if not len(self.species_dict[label].rotors_dict[rotor_index]['cont_indices']): - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'] = [0] * len(torsions) - if not len(self.species_dict[label].rotors_dict[rotor_index]['original_dihedrals']): - self.species_dict[label].rotors_dict[rotor_index]['original_dihedrals'] = \ - [f'{calculate_dihedral_angle(coords=xyz["coords"], torsion=scan, index=1):.2f}' - for scan in self.species_dict[label].rotors_dict[rotor_index]['scan']] # stores as str for YAML rotor_dict = self.species_dict[label].rotors_dict[rotor_index] + initialize_continuous_scan_state(rotor_dict, xyz) torsions = rotor_dict['torsion'] - max_num = 360 / increment + 1 # dihedral angles per scan - original_dihedrals = list() - for dihedral in rotor_dict['original_dihedrals']: - 
original_dihedrals.append(get_angle_in_180_range(dihedral)) - if not any(self.species_dict[label].rotors_dict[rotor_index]['cont_indices']): + original_dihedrals = [get_angle_in_180_range(float(d)) for d in rotor_dict['original_dihedrals']] + if not any(rotor_dict['cont_indices']): # This is the first call for this cont_opt directed rotor, spawn the first job w/o changing dihedrals. self.run_job(label=label, xyz=self.species_dict[label].final_xyz, @@ -1727,7 +1726,7 @@ def spawn_directed_scan_jobs(self, dihedrals=original_dihedrals, rotor_index=rotor_index, ) - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][0] += 1 + rotor_dict['cont_indices'][0] += 1 return else: # this is NOT the first call for this cont_opt directed rotor, check that ``xyz`` was given. @@ -1735,26 +1734,16 @@ def spawn_directed_scan_jobs(self, # xyz is None only at the first time cont opt is spawned, where cont_index is [0, 0,... 0]. raise InputError('xyz argument must be given for a continuous scan job') # check whether this rotor is done - if self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][-1] == max_num - 1: # 0-indexed + if is_continuous_scan_complete(rotor_dict, increment): # no more counters to increment, all done! 
logger.info(f'Completed all jobs for the continuous directed rotor scan for species {label} ' f'between pivots {rotor_dict["pivots"]}') self.process_directed_scans(label, rotor_dict['pivots']) return + dihedrals = get_continuous_scan_dihedrals(rotor_dict, increment) modified_xyz = xyz - dihedrals = list() - for index, (original_dihedral, torsion) in enumerate(zip(original_dihedrals, torsions)): - dihedral = original_dihedral + \ - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][index] * increment - # Change the original dihedral so we won't end up with two calcs for 180.0, but none for -180.0 - # (it only matters for plotting, the geometry is of course the same) - dihedral = get_angle_in_180_range(dihedral) - dihedrals.append(dihedral) - # Only change the dihedrals in the xyz if this torsion corresponds to the current index, - # or if this is a diagonal scan. - # Species.set_dihedral() uses .final_xyz or the given xyz to modify the .initial_xyz - # attribute to the desired dihedral. 
+ for torsion, dihedral in zip(torsions, dihedrals): self.species_dict[label].set_dihedral(scan=torsion, index=0, deg_abs=dihedral, @@ -1771,19 +1760,75 @@ def spawn_directed_scan_jobs(self, rotor_index=rotor_index, ) - if 'diagonal' in directed_scan_type: - # increment ALL counters for a diagonal scan - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'] = \ - [self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][0] + 1] * len(torsions) - else: - # increment the counter sequentially (non-diagonal scan) - for index in range(len(torsions)): - if self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][index] < max_num - 1: - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][index] += 1 - break - elif (self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][index] == max_num - 1 - and index < len(torsions) - 1): - self.species_dict[label].rotors_dict[rotor_index]['cont_indices'][index] = 0 + increment_continuous_scan_indices(rotor_dict, increment, + diagonal='diagonal' in directed_scan_type) + + def _spawn_adaptive_brute_force_jobs(self, + label: str, + rotor_index: int, + xyz: dict, + increment: float, + ): + """ + Spawn the next batch of adaptive brute-force scan jobs. + + This is called from within the brute-force path of ``spawn_directed_scan_jobs`` + when adaptive execution policy is enabled. It selects points via the adaptive + helpers in ``nd_scan.py`` and submits them using the same job machinery as + dense brute-force. + + Args: + label (str): The species label. + rotor_index (int): The 0-indexed rotor number in ``species.rotors_dict``. + xyz (dict): The 3D coordinates for building constrained geometries. + increment (float): The scan resolution in degrees. 
+ """ + rotor_dict = self.species_dict[label].rotors_dict[rotor_index] + torsions = rotor_dict['torsion'] + directed_scan_type = rotor_dict['directed_scan_type'] + + # Initialize adaptive state if this is the first call + initialize_adaptive_scan_state(rotor_dict, xyz, increment) + + # Check if we should stop + if not should_continue_adaptive_scan(rotor_dict, increment): + if not get_pending_adaptive_points(rotor_dict): + logger.info(f'Adaptive scan for species {label} between pivots {rotor_dict["pivots"]} ' + f'is complete: {rotor_dict["adaptive_scan"]["stopping_reason"]}') + self.process_directed_scans(label, rotor_dict['pivots']) + return + + # Select the next batch + points = select_next_adaptive_points(rotor_dict, increment) + if not points: + if not get_pending_adaptive_points(rotor_dict): + rotor_dict['adaptive_scan']['stopping_reason'] = 'no_candidates' + logger.info(f'Adaptive scan for species {label} between pivots {rotor_dict["pivots"]} ' + f'is complete: no_candidates') + self.process_directed_scans(label, rotor_dict['pivots']) + return + + # Mark them pending and submit jobs + mark_scan_points_pending(rotor_dict, points) + for point in points: + modified_xyz = xyz + for torsion, dihedral in zip(torsions, point): + self.species_dict[label].set_dihedral(scan=torsion, + index=0, + deg_abs=dihedral, + count=False, + xyz=modified_xyz) + modified_xyz = self.species_dict[label].initial_xyz + rotor_dict['number_of_running_jobs'] += 1 + self.run_job(label=label, + xyz=modified_xyz, + level_of_theory=self.scan_level, + job_type='directed_scan', + directed_scan_type=directed_scan_type, + torsions=torsions, + dihedrals=list(point), + rotor_index=rotor_index, + ) def process_directed_scans(self, label: str, pivots: Union[List[int], List[List[int]]]): """ @@ -1793,52 +1838,34 @@ def process_directed_scans(self, label: str, pivots: Union[List[int], List[List[ label (str): The species label. pivots (Union[List[int], List[List[int]]]): The rotor pivots. 
""" - for rotor_dict_index in self.species_dict[label].rotors_dict.keys(): - rotor_dict = self.species_dict[label].rotors_dict[rotor_dict_index] # avoid modifying the iterator - if rotor_dict['pivots'] == pivots: - # identified a directed scan (either continuous or brute force, they're treated the same here) - dihedrals = [[float(dihedral) for dihedral in dihedral_string_tuple] - for dihedral_string_tuple in rotor_dict['directed_scan'].keys()] - sorted_dihedrals = sorted(dihedrals) - min_energy = extremum_list([directed_scan_dihedral['energy'] - for directed_scan_dihedral in rotor_dict['directed_scan'].values()], - return_min=True) - trshed_points = 0 - if rotor_dict['directed_scan_type'] == 'ess': - # parse the single output file - results = parser.parse_nd_scan_energies(log_file_path=rotor_dict['scan_path'])[0] - else: - results = {'directed_scan_type': rotor_dict['directed_scan_type'], - 'scans': rotor_dict['scan'], - 'directed_scan': rotor_dict['directed_scan']} - for dihedral_list in sorted_dihedrals: - dihedrals_key = tuple(f'{dihedral:.2f}' for dihedral in dihedral_list) - dihedral_dict = results['directed_scan'][dihedrals_key] - if dihedral_dict['trsh']: - trshed_points += 1 - if dihedral_dict['energy'] is not None: - dihedral_dict['energy'] -= min_energy # set 0 at the minimal energy - folder_name = 'rxns' if self.species_dict[label].is_ts else 'Species' - rotor_yaml_file_path = os.path.join(self.project_directory, 'output', folder_name, label, 'rotors', - f'{pivots}_{rotor_dict["directed_scan_type"]}.yml') - plotter.save_nd_rotor_yaml(results, path=rotor_yaml_file_path) - self.species_dict[label].rotors_dict[rotor_dict_index]['scan_path'] = rotor_yaml_file_path - if trshed_points: - logger.warning(f'Directed rotor scan for species {label} between pivots {rotor_dict["pivots"]} ' - f'had {trshed_points} points that required optimization troubleshooting.') - rotor_path = os.path.join(self.project_directory, 'output', folder_name, label, 'rotors') - if 
len(results['scans']) == 1: - plotter.plot_1d_rotor_scan( - results=results, - path=rotor_path, - scan=rotor_dict['scan'][0], - label=label, - original_dihedral=self.species_dict[label].rotors_dict[rotor_dict_index]['original_dihedrals'], - ) - elif len(results['scans']) == 2: - plotter.plot_2d_rotor_scan(results=results, path=rotor_path) - else: - logger.debug('Not plotting ND rotors with N > 2') + match = get_rotor_dict_by_pivots(self.species_dict[label].rotors_dict, pivots) + if match is not None: + rotor_dict_index, rotor_dict = match + # identified a directed scan (either continuous or brute force, they're treated the same here) + results, trshed_points = finalize_directed_scan_results( + rotor_dict, parse_nd_scan_energies_func=parser.parse_nd_scan_energies, + increment=rotor_scan_resolution) + folder_name = 'rxns' if self.species_dict[label].is_ts else 'Species' + rotor_yaml_file_path = os.path.join(self.project_directory, 'output', folder_name, label, 'rotors', + f'{pivots}_{rotor_dict["directed_scan_type"]}.yml') + plotter.save_nd_rotor_yaml(results, path=rotor_yaml_file_path) + self.species_dict[label].rotors_dict[rotor_dict_index]['scan_path'] = rotor_yaml_file_path + if trshed_points: + logger.warning(f'Directed rotor scan for species {label} between pivots {rotor_dict["pivots"]} ' + f'had {trshed_points} points that required optimization troubleshooting.') + rotor_path = os.path.join(self.project_directory, 'output', folder_name, label, 'rotors') + if len(results['scans']) == 1: + plotter.plot_1d_rotor_scan( + results=results, + path=rotor_path, + scan=rotor_dict['scan'][0], + label=label, + original_dihedral=self.species_dict[label].rotors_dict[rotor_dict_index]['original_dihedrals'], + ) + elif len(results['scans']) == 2: + plotter.plot_2d_rotor_scan(results=results, path=rotor_path) + else: + logger.debug('Not plotting ND rotors with N > 2') def process_conformers(self, label): """ @@ -3024,19 +3051,56 @@ def check_directed_scan_job(self, label: str, 
job: 'JobAdapter'): if job.job_status[1]['status'] == 'done': xyz = parser.parse_geometry(log_file_path=job.local_path_to_output_file) is_isomorphic = self.species_dict[label].check_xyz_isomorphism(xyz=xyz, verbose=False) - for rotor_dict in self.species_dict[label].rotors_dict.values(): - if rotor_dict['pivots'] == job.pivots: - key = tuple(f'{dihedral:.2f}' for dihedral in job.dihedrals) - rotor_dict['directed_scan'][key] = {'energy': parser.parse_e_elect( - path=job.local_path_to_output_file), - 'xyz': xyz, - 'is_isomorphic': is_isomorphic, - 'trsh': job.ess_trsh_methods, - } + energy = parser.parse_e_elect(log_file_path=job.local_path_to_output_file) + match = get_rotor_dict_by_pivots(self.species_dict[label].rotors_dict, job.pivots) + if match is not None: + _, rotor_dict = match + if is_adaptive_enabled(rotor_dict): + # Record into adaptive bookkeeping and legacy directed_scan. + # Route non-isomorphic points to invalid tracking. + if is_isomorphic: + mark_scan_point_completed( + rotor_dict=rotor_dict, + point=job.dihedrals, + energy=energy, + xyz=xyz, + is_isomorphic=is_isomorphic, + trsh=job.ess_trsh_methods, + ) + else: + mark_scan_point_invalid(rotor_dict, job.dihedrals, reason='non-isomorphic') + record_directed_scan_point( + rotor_dict=rotor_dict, + dihedrals=job.dihedrals, + energy=energy, + xyz=xyz, + is_isomorphic=False, + trsh=job.ess_trsh_methods, + ) + else: + record_directed_scan_point( + rotor_dict=rotor_dict, + dihedrals=job.dihedrals, + energy=energy, + xyz=xyz, + is_isomorphic=is_isomorphic, + trsh=job.ess_trsh_methods, + ) else: + # Snapshot running jobs before troubleshooting to detect resubmission. + jobs_before = set(self.running_jobs.get(label, [])) self.troubleshoot_ess(label=label, job=job, level_of_theory=self.scan_level) + # Only mark the scan point as failed if troubleshooting did NOT + # resubmit the job. Compare running_jobs before/after to detect + # whether a new scan job was added for this specific label. 
+ match = get_rotor_dict_by_pivots(self.species_dict[label].rotors_dict, job.pivots) + if match is not None and is_adaptive_enabled(match[1]): + jobs_after = set(self.running_jobs.get(label, [])) + resubmitted = len(jobs_after - jobs_before) > 0 + if not resubmitted: + mark_scan_point_failed(match[1], job.dihedrals) def check_all_done(self, label: str): """ diff --git a/arc/scheduler_test.py b/arc/scheduler_test.py index 3216a9f254..5a19cafc23 100644 --- a/arc/scheduler_test.py +++ b/arc/scheduler_test.py @@ -10,17 +10,24 @@ import shutil import arc.parser.parser as parser +import arc.scheduler as sched_module from arc.checks.ts import check_ts from arc.common import ARC_PATH, ARC_TESTING_PATH, almost_equal_coords_lists, initialize_job_types, read_yaml_file from arc.job.factory import job_factory from arc.level import Level from arc.plotter import save_conformers_file -from arc.scheduler import Scheduler, species_has_freq, species_has_geo, species_has_sp, species_has_sp_and_freq from arc.imports import settings from arc.reaction import ARCReaction from arc.species.converter import str_to_xyz +from arc.species.nd_scan import decrement_running_jobs, is_adaptive_enabled, is_adaptive_scan_complete, point_to_key from arc.species.species import ARCSpecies +Scheduler = sched_module.Scheduler +species_has_freq = sched_module.species_has_freq +species_has_geo = sched_module.species_has_geo +species_has_sp = sched_module.species_has_sp +species_has_sp_and_freq = sched_module.species_has_sp_and_freq + default_levels_of_theory = settings['default_levels_of_theory'] @@ -769,5 +776,241 @@ def tearDownClass(cls): shutil.rmtree(project_directory, ignore_errors=True) +class TestDirectedScanFunctional(unittest.TestCase): + """ + Functional tests for directed scan workflows using the xTB adapter. + These exercise the real scheduler -> job -> parser pipeline on small molecules. 
+ + Uses a coarse 120-degree scan resolution (4 grid points per dimension) + so tests complete in seconds, not hours. + """ + _original_resolution = None + _xtb_available = None + + @classmethod + def setUpClass(cls): + """Set up a Scheduler with xTB scan level and an ethanol species.""" + try: + from arc.job.adapters.xtb_adapter import xTBAdapter + import tempfile + td = tempfile.mkdtemp() + job = xTBAdapter(execution_type='incore', job_type='sp', project='xtb_check', + project_directory=td, species=[ARCSpecies(label='H2', smiles='[H][H]')]) + job.execute_incore() + cls._xtb_available = os.path.isfile(job.local_path_to_output_file) + shutil.rmtree(td, ignore_errors=True) + except Exception: + cls._xtb_available = False + if not cls._xtb_available: + return + cls.maxDiff = None + cls.project = 'arc_directed_scan_functional_test' + cls.project_directory = os.path.join(ARC_PATH, 'Projects', cls.project) + cls.ess_settings = {'gaussian': ['server1']} + xtb_level = Level(method='gfn2') + job_types = initialize_job_types({'rotors': True, 'opt': False, 'freq': False, + 'sp': False, 'conf_opt': False, 'conf_sp': False}) + + cls.spc = ARCSpecies(label='EtOH', smiles='CCO') + cls.spc.initial_xyz = cls.spc.get_xyz() + cls.spc.final_xyz = cls.spc.initial_xyz + cls.spc.determine_rotors() + + # Patch scan resolution to 120 degrees for fast tests (4 points per dim, 16 total for 2D) + cls._original_resolution = sched_module.rotor_scan_resolution + sched_module.rotor_scan_resolution = 120.0 + + cls.sched = Scheduler(project=cls.project, + ess_settings=cls.ess_settings, + species_list=[cls.spc], + composite_method=None, + conformer_opt_level=xtb_level, + opt_level=xtb_level, + freq_level=xtb_level, + sp_level=xtb_level, + scan_level=xtb_level, + ts_guess_level=xtb_level, + project_directory=cls.project_directory, + testing=True, + job_types=job_types, + orbitals_level=xtb_level, + adaptive_levels=None) + + def _skip_if_no_xtb(self): + """Skip the test if xTB is not available.""" + 
if not self._xtb_available:
+ self.skipTest('xTB is not installed or not available')
+
+ def _process_completed_brute_force_jobs(self, label, rotor_index):
+ """
+ Helper: after spawn_directed_scan_jobs has submitted (and executed incore)
+ all brute-force jobs, iterate through them calling check_directed_scan_job
+ and decrement_running_jobs to simulate the scheduler main loop.
+
+ Uses an iterative loop instead of recursion to avoid counter desync and
+ stack overflow for adaptive scans that spawn multiple batches.
+ """
+ max_rounds = 20
+ increment = sched_module.rotor_scan_resolution
+ rotor_dict = self.sched.species_dict[label].rotors_dict[rotor_index]
+ processed_names = set()  # persists across rounds so each job is processed exactly once
+
+ for _ in range(max_rounds):
+ if 'directed_scan' not in self.sched.job_dict[label]:
+ return
+ # Snapshot the current batch of not-yet-processed jobs
+ jobs_batch = [(name, job) for name, job in self.sched.job_dict[label]['directed_scan'].items()
+ if job.rotor_index == rotor_index and name not in processed_names]
+ for job_name, job in jobs_batch:
+ processed_names.add(job_name)
+ self.sched.check_directed_scan_job(label=label, job=job)
+ if 'brute_force' in job.directed_scan_type:
+ decrement_running_jobs(rotor_dict)
+ # After processing this batch, check if adaptive needs more
+ if is_adaptive_enabled(rotor_dict) \
+ and rotor_dict['number_of_running_jobs'] == 0 \
+ and not is_adaptive_scan_complete(rotor_dict, increment):
+ self.sched.spawn_directed_scan_jobs(label=label, rotor_index=rotor_index)
+ continue # process the new batch
+ return # done
+
+ def _make_2d_rotor(self, label, scan_type='brute_force_opt', adaptive=False):
+ """
+ Helper: create a 2D rotor combining both ethanol rotors and add it to the species.
+ Returns the rotor index.
+ """ + spc = self.sched.species_dict[label] + rotor_0 = spc.rotors_dict[0] + rotor_1 = spc.rotors_dict[1] + rotor_2d_index = max(spc.rotors_dict.keys()) + 1 + rotor_2d = { + 'pivots': [rotor_0['pivots'], rotor_1['pivots']], + 'top': [rotor_0['top'], rotor_1['top']], + 'scan': [rotor_0['scan'], rotor_1['scan']], + 'torsion': [rotor_0['torsion'], rotor_1['torsion']], + 'number_of_running_jobs': 0, + 'success': None, + 'invalidation_reason': '', + 'times_dihedral_set': 0, + 'trsh_counter': 0, + 'trsh_methods': [], + 'scan_path': '', + 'directed_scan_type': scan_type, + 'directed_scan': {}, + 'dimensions': 2, + 'original_dihedrals': [], + 'cont_indices': [], + 'symmetry': None, + 'max_e': None, + } + if adaptive: + rotor_2d['sampling_policy'] = 'adaptive' + spc.rotors_dict[rotor_2d_index] = rotor_2d + self.sched.job_dict[label]['directed_scan'] = {} + return rotor_2d_index + + def test_dense_2d_brute_force_scan(self): + """Functional test: dense 2D brute_force_opt scan on ethanol 2D rotor via xTB.""" + self._skip_if_no_xtb() + label = 'EtOH' + spc = self.sched.species_dict[label] + rotor_2d_index = self._make_2d_rotor(label, scan_type='brute_force_opt') + rotor_2d = spc.rotors_dict[rotor_2d_index] + + self.sched.spawn_directed_scan_jobs(label=label, rotor_index=rotor_2d_index) + + n_jobs = rotor_2d['number_of_running_jobs'] + self.assertGreater(n_jobs, 0, 'No 2D brute-force jobs were spawned') + + self._process_completed_brute_force_jobs(label, rotor_2d_index) + + directed = rotor_2d['directed_scan'] + self.assertGreater(len(directed), 0, 'No 2D directed scan entries recorded') + for key, entry in directed.items(): + self.assertIsInstance(key, tuple) + self.assertEqual(len(key), 2, 'Expected 2D scan keys') + self.assertIsNotNone(entry['energy'], f'Energy is None for key {key}') + + del spc.rotors_dict[rotor_2d_index] + + def test_adaptive_2d_brute_force_scan(self): + """Functional test: adaptive 2D brute_force_opt scan on ethanol via xTB.""" + self._skip_if_no_xtb() + 
label = 'EtOH' + spc = self.sched.species_dict[label] + rotor_2d_index = self._make_2d_rotor(label, scan_type='brute_force_opt', adaptive=True) + rotor_2d = spc.rotors_dict[rotor_2d_index] + self.assertTrue(is_adaptive_enabled(rotor_2d)) + + self.sched.spawn_directed_scan_jobs(label=label, rotor_index=rotor_2d_index) + + n_jobs = rotor_2d['number_of_running_jobs'] + self.assertGreater(n_jobs, 0, 'No adaptive jobs were spawned') + + # Adaptive state should have been initialized + self.assertIn('adaptive_scan', rotor_2d) + state = rotor_2d['adaptive_scan'] + self.assertTrue(state['enabled']) + self.assertGreater(len(state['seed_points']), 0) + + self._process_completed_brute_force_jobs(label, rotor_2d_index) + + # Verify adaptive bookkeeping + state = rotor_2d['adaptive_scan'] + self.assertGreater(len(state['completed_points']), 0) + self.assertEqual(len(state['pending_points']), 0) + + # Verify legacy directed_scan is populated + directed = rotor_2d['directed_scan'] + self.assertGreater(len(directed), 0) + for key, entry in directed.items(): + self.assertEqual(len(key), 2, 'Expected 2D scan keys') + self.assertIsNotNone(entry['energy']) + + # Check that every completed adaptive point is in legacy directed_scan + for pt in state['completed_points']: + key = point_to_key(tuple(pt)) + self.assertIn(key, directed, f'Completed point {pt} not in directed_scan') + + # If validation ran, check its structure + if 'validation' in state: + val = state['validation'] + self.assertIn('status', val) + self.assertIn('thresholds', val) + + del spc.rotors_dict[rotor_2d_index] + + def test_dense_2d_brute_force_sp_scan(self): + """Functional test: dense 2D brute_force_sp scan on ethanol via xTB.""" + self._skip_if_no_xtb() + label = 'EtOH' + spc = self.sched.species_dict[label] + rotor_2d_index = self._make_2d_rotor(label, scan_type='brute_force_sp') + rotor_2d = spc.rotors_dict[rotor_2d_index] + + self.sched.spawn_directed_scan_jobs(label=label, rotor_index=rotor_2d_index) + + n_jobs 
= rotor_2d['number_of_running_jobs'] + self.assertGreater(n_jobs, 0, 'No 2D brute-force SP jobs were spawned') + + self._process_completed_brute_force_jobs(label, rotor_2d_index) + + directed = rotor_2d['directed_scan'] + self.assertGreater(len(directed), 0, 'No 2D SP scan entries recorded') + for key, entry in directed.items(): + self.assertEqual(len(key), 2) + self.assertIsNotNone(entry['energy']) + + del spc.rotors_dict[rotor_2d_index] + + @classmethod + def tearDownClass(cls): + """Clean up project directory and restore scan resolution.""" + if cls._original_resolution is not None: + sched_module.rotor_scan_resolution = cls._original_resolution + if hasattr(cls, 'project_directory'): + shutil.rmtree(cls.project_directory, ignore_errors=True) + + if __name__ == '__main__': unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))