diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 62de61ac6a..de5b2ef1d5 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -48,6 +48,18 @@ from deepmd.utils.data_system import ( process_systems, ) +from deepmd.utils.eval_metrics import ( + DP_TEST_HESSIAN_METRIC_KEYS, + DP_TEST_SPIN_WEIGHTED_METRIC_KEYS, + DP_TEST_WEIGHTED_FORCE_METRIC_KEYS, + DP_TEST_WEIGHTED_METRIC_KEYS, + compute_energy_type_metrics, + compute_error_stat, + compute_spin_force_metrics, + compute_weighted_error_stat, + mae, + rmse, +) from deepmd.utils.weight_avg import ( weighted_average, ) @@ -244,38 +256,6 @@ def test( log.info("# ----------------------------------------------- ") -def mae(diff: np.ndarray) -> float: - """Calcalte mean absulote error. - - Parameters - ---------- - diff : np.ndarray - difference - - Returns - ------- - float - mean absulote error - """ - return np.mean(np.abs(diff)) - - -def rmse(diff: np.ndarray) -> float: - """Calculate root mean square error. 
- - Parameters - ---------- - diff : np.ndarray - difference - - Returns - ------- - float - root mean square error - """ - return np.sqrt(np.average(diff * diff)) - - def save_txt_file( fname: Path, data: np.ndarray, header: str = "", append: bool = False ) -> None: @@ -297,6 +277,231 @@ def save_txt_file( np.savetxt(fp, data, header=header) +def _reshape_force_by_atom(force_array: np.ndarray, natoms: int) -> np.ndarray: + """Reshape flattened force arrays into `[nframes, natoms, 3]`.""" + return np.reshape(force_array, [-1, natoms, 3]) + + +def _concat_force_rows( + force_blocks: list[np.ndarray], dtype: np.dtype | type[np.generic] +) -> np.ndarray: + """Concatenate per-frame force rows into one 2D array.""" + if not force_blocks: + return np.empty((0, 3), dtype=dtype) + return np.concatenate(force_blocks, axis=0) + + +def _align_spin_force_arrays( + *, + dp: "DeepPot", + atype: np.ndarray, + natoms: int, + prediction_force: np.ndarray, + reference_force: np.ndarray, + prediction_force_mag: np.ndarray | None, + reference_force_mag: np.ndarray | None, + mask_mag: np.ndarray | None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None]: + """Align spin force arrays into real-atom and magnetic subsets.""" + prediction_force_by_atom = _reshape_force_by_atom(prediction_force, natoms) + reference_force_by_atom = _reshape_force_by_atom(reference_force, natoms) + if dp.get_ntypes_spin() != 0: # old tf support for spin + ntypes_real = dp.get_ntypes() - dp.get_ntypes_spin() + atype_by_frame = np.reshape(atype, [-1, natoms]) + if atype_by_frame.shape[0] == 1 and prediction_force_by_atom.shape[0] != 1: + atype_by_frame = np.broadcast_to( + atype_by_frame, + (prediction_force_by_atom.shape[0], natoms), + ) + if atype_by_frame.shape[0] != prediction_force_by_atom.shape[0]: + raise ValueError( + "Spin atom types and force arrays must have matching frames." 
+ ) + force_real_prediction_chunks = [] + force_real_reference_chunks = [] + force_magnetic_prediction_chunks = [] + force_magnetic_reference_chunks = [] + for frame_atype, frame_prediction, frame_reference in zip( + atype_by_frame, + prediction_force_by_atom, + reference_force_by_atom, + strict=False, + ): + real_mask = frame_atype < ntypes_real + magnetic_mask = ~real_mask + force_real_prediction_chunks.append(frame_prediction[real_mask]) + force_real_reference_chunks.append(frame_reference[real_mask]) + force_magnetic_prediction_chunks.append(frame_prediction[magnetic_mask]) + force_magnetic_reference_chunks.append(frame_reference[magnetic_mask]) + return ( + _concat_force_rows( + force_real_prediction_chunks, + prediction_force_by_atom.dtype, + ), + _concat_force_rows( + force_real_reference_chunks, + reference_force_by_atom.dtype, + ), + _concat_force_rows( + force_magnetic_prediction_chunks, + prediction_force_by_atom.dtype, + ), + _concat_force_rows( + force_magnetic_reference_chunks, + reference_force_by_atom.dtype, + ), + ) + + force_real_prediction = prediction_force_by_atom.reshape(-1, 3) + force_real_reference = reference_force_by_atom.reshape(-1, 3) + if prediction_force_mag is None or reference_force_mag is None or mask_mag is None: + return force_real_prediction, force_real_reference, None, None + magnetic_mask = mask_mag.reshape(-1).astype(bool) + return ( + force_real_prediction, + force_real_reference, + prediction_force_mag.reshape(-1, 3)[magnetic_mask], + reference_force_mag.reshape(-1, 3)[magnetic_mask], + ) + + +def _write_energy_test_details( + *, + detail_path: Path, + system: str, + natoms: int, + append_detail: bool, + reference_energy: np.ndarray, + prediction_energy: np.ndarray, + reference_force: np.ndarray, + prediction_force: np.ndarray, + reference_virial: np.ndarray | None, + prediction_virial: np.ndarray | None, + out_put_spin: bool, + reference_force_real: np.ndarray | None = None, + prediction_force_real: np.ndarray | None = 
None, + reference_force_magnetic: np.ndarray | None = None, + prediction_force_magnetic: np.ndarray | None = None, + reference_hessian: np.ndarray | None = None, + prediction_hessian: np.ndarray | None = None, +) -> None: + """Write energy-type detail outputs after arrays have been aligned.""" + pe = np.concatenate( + ( + np.reshape(reference_energy, [-1, 1]), + np.reshape(prediction_energy, [-1, 1]), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".e.out"), + pe, + header=f"{system}: data_e pred_e", + append=append_detail, + ) + pe_atom = pe / natoms + save_txt_file( + detail_path.with_suffix(".e_peratom.out"), + pe_atom, + header=f"{system}: data_e pred_e", + append=append_detail, + ) + if not out_put_spin: + pf = np.concatenate( + ( + np.reshape(reference_force, [-1, 3]), + np.reshape(prediction_force, [-1, 3]), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".f.out"), + pf, + header=f"{system}: data_fx data_fy data_fz pred_fx pred_fy pred_fz", + append=append_detail, + ) + else: + if reference_force_real is None or prediction_force_real is None: + raise ValueError("Spin detail output requires aligned real-atom forces.") + pf_real = np.concatenate( + ( + np.reshape(reference_force_real, [-1, 3]), + np.reshape(prediction_force_real, [-1, 3]), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".fr.out"), + pf_real, + header=f"{system}: data_fx data_fy data_fz pred_fx pred_fy pred_fz", + append=append_detail, + ) + if (reference_force_magnetic is None) != (prediction_force_magnetic is None): + raise ValueError( + "Spin magnetic detail output requires both reference and prediction forces." 
+ ) + if ( + reference_force_magnetic is not None + and prediction_force_magnetic is not None + ): + pf_mag = np.concatenate( + ( + np.reshape(reference_force_magnetic, [-1, 3]), + np.reshape(prediction_force_magnetic, [-1, 3]), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".fm.out"), + pf_mag, + header=f"{system}: data_fmx data_fmy data_fmz pred_fmx pred_fmy pred_fmz", + append=append_detail, + ) + if (reference_virial is None) != (prediction_virial is None): + raise ValueError( + "Virial detail output requires both reference and prediction virials." + ) + if reference_virial is not None and prediction_virial is not None: + pv = np.concatenate( + ( + np.reshape(reference_virial, [-1, 9]), + np.reshape(prediction_virial, [-1, 9]), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".v.out"), + pv, + header=f"{system}: data_vxx data_vxy data_vxz data_vyx data_vyy " + "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " + "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", + append=append_detail, + ) + pv_atom = pv / natoms + save_txt_file( + detail_path.with_suffix(".v_peratom.out"), + pv_atom, + header=f"{system}: data_vxx data_vxy data_vxz data_vyx data_vyy " + "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " + "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", + append=append_detail, + ) + if reference_hessian is not None and prediction_hessian is not None: + hessian_detail = np.concatenate( + ( + reference_hessian.reshape(-1, 1), + prediction_hessian.reshape(-1, 1), + ), + axis=1, + ) + save_txt_file( + detail_path.with_suffix(".h.out"), + hessian_detail, + header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)", + append=append_detail, + ) + + def test_ener( dp: "DeepPot", data: DeepmdData, @@ -305,7 +510,7 @@ def test_ener( detail_file: str | None, has_atom_ener: bool, append_detail: bool = False, -) -> tuple[list[np.ndarray], list[int]]: +) -> dict[str, tuple[float, float]]: 
"""Test energy type model. Parameters @@ -327,8 +532,8 @@ def test_ener( Returns ------- - tuple[list[np.ndarray], list[int]] - arrays with results and their shapes + dict[str, tuple[float, float]] + weighted-average-ready metric pairs """ dict_to_return = {} @@ -409,6 +614,9 @@ def test_ener( energy = energy.reshape([numb_test, 1]) force = force.reshape([numb_test, -1]) virial = virial.reshape([numb_test, 9]) + hessian = None + force_m = None + mask_mag = None if dp.has_hessian: hessian = ret[3] hessian = hessian.reshape([numb_test, -1]) @@ -429,75 +637,114 @@ def test_ener( mask_mag = ret[4] mask_mag = mask_mag.reshape([numb_test, -1]) out_put_spin = dp.get_ntypes_spin() != 0 or dp.has_spin + spin_metrics = None + force_r = None + test_force_r = None + test_force_m = None if out_put_spin: - if dp.get_ntypes_spin() != 0: # old tf support for spin - ntypes_real = dp.get_ntypes() - dp.get_ntypes_spin() - nloc = natoms - nloc_real = sum( - [np.count_nonzero(atype == ii) for ii in range(ntypes_real)] + force_r, test_force_r, force_m, test_force_m = _align_spin_force_arrays( + dp=dp, + atype=atype, + natoms=natoms, + prediction_force=force, + reference_force=test_data["force"][:numb_test], + prediction_force_mag=force_m, + reference_force_mag=( + test_data["force_mag"][:numb_test] if "force_mag" in test_data else None + ), + mask_mag=mask_mag, + ) + if find_force_mag == 1 and (force_m is None or test_force_m is None): + raise RuntimeError( + "Spin magnetic force metrics require magnetic force arrays and mask." 
) - force_r = np.split( - force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 - )[0] - force_m = np.split( - force, indices_or_sections=[nloc_real * 3, nloc * 3], axis=1 - )[1] - test_force_r = np.split( - test_data["force"][:numb_test], - indices_or_sections=[nloc_real * 3, nloc * 3], - axis=1, - )[0] - test_force_m = np.split( + spin_metrics = compute_spin_force_metrics( + force_real_prediction=force_r, + force_real_reference=test_force_r, + force_magnetic_prediction=force_m if find_force_mag == 1 else None, + force_magnetic_reference=test_force_m if find_force_mag == 1 else None, + ) + + energy_metric_input = { + "find_energy": find_energy, + "find_force": find_force if not out_put_spin else 0.0, + "find_virial": find_virial if not out_put_spin else 0.0, + "energy": test_data["energy"][:numb_test], + "force": test_data["force"][:numb_test], + } + energy_metric_prediction = { + "energy": energy, + "force": force, + } + if find_virial == 1 and data.pbc and not out_put_spin: + energy_metric_input["virial"] = test_data["virial"][:numb_test] + energy_metric_prediction["virial"] = virial + shared_metrics = compute_energy_type_metrics( + prediction=energy_metric_prediction, + test_data=energy_metric_input, + natoms=natoms, + has_pbc=data.pbc, + ) + dict_to_return.update( + shared_metrics.as_weighted_average_errors(DP_TEST_WEIGHTED_METRIC_KEYS) + ) + + weighted_force_metrics = None + if find_energy == 1: + if shared_metrics.energy is None or shared_metrics.energy_per_atom is None: + raise RuntimeError("Energy metrics are unavailable for dp test.") + mae_e = shared_metrics.energy.mae + rmse_e = shared_metrics.energy.rmse + mae_ea = shared_metrics.energy_per_atom.mae + rmse_ea = shared_metrics.energy_per_atom.rmse + + if not out_put_spin and find_force == 1: + if shared_metrics.force is None: + raise RuntimeError("Force metrics are unavailable for dp test.") + mae_f = shared_metrics.force.mae + rmse_f = shared_metrics.force.rmse + if find_atom_pref == 1: + 
weighted_force_metrics = compute_weighted_error_stat( + force, test_data["force"][:numb_test], - indices_or_sections=[nloc_real * 3, nloc * 3], - axis=1, - )[1] - else: # pt support for spin - force_r = force - test_force_r = test_data["force"][:numb_test] - # The shape of force_m and test_force_m are [-1, 3], - # which is designed for mixed_type cases - force_m = force_m.reshape(-1, 3)[mask_mag.reshape(-1)] - test_force_m = test_data["force_mag"][:numb_test].reshape(-1, 3)[ - mask_mag.reshape(-1) - ] + test_data["atom_pref"][:numb_test], + ) + mae_fw = weighted_force_metrics.mae + rmse_fw = weighted_force_metrics.rmse - diff_e = energy - test_data["energy"][:numb_test].reshape([-1, 1]) - mae_e = mae(diff_e) - rmse_e = rmse(diff_e) - diff_f = force - test_data["force"][:numb_test] - mae_f = mae(diff_f) - rmse_f = rmse(diff_f) - size_f = diff_f.size - if find_atom_pref == 1: - atom_weight = test_data["atom_pref"][:numb_test] - weight_sum = np.sum(atom_weight) - if weight_sum > 0: - mae_fw = np.sum(np.abs(diff_f) * atom_weight) / weight_sum - rmse_fw = np.sqrt(np.sum(diff_f * diff_f * atom_weight) / weight_sum) - else: - mae_fw = 0.0 - rmse_fw = 0.0 - diff_v = virial - test_data["virial"][:numb_test] - mae_v = mae(diff_v) - rmse_v = rmse(diff_v) - mae_ea = mae_e / natoms - rmse_ea = rmse_e / natoms - mae_va = mae_v / natoms - rmse_va = rmse_v / natoms + if data.pbc and not out_put_spin and find_virial == 1: + if shared_metrics.virial is None or shared_metrics.virial_per_atom is None: + raise RuntimeError("Virial metrics are unavailable for dp test.") + mae_v = shared_metrics.virial.mae + rmse_v = shared_metrics.virial.rmse + mae_va = shared_metrics.virial_per_atom.mae + rmse_va = shared_metrics.virial_per_atom.rmse + + hessian_metrics = None if dp.has_hessian: - diff_h = hessian - test_data["hessian"][:numb_test] - mae_h = mae(diff_h) - rmse_h = rmse(diff_h) + hessian_metrics = compute_error_stat( + hessian, + test_data["hessian"][:numb_test], + ) + mae_h = 
hessian_metrics.mae + rmse_h = hessian_metrics.rmse if has_atom_ener: - diff_ae = test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1]) - mae_ae = mae(diff_ae) - rmse_ae = rmse(diff_ae) + atomic_energy_metrics = compute_error_stat( + ae.reshape([-1]), + test_data["atom_ener"][:numb_test].reshape([-1]), + ) + mae_ae = atomic_energy_metrics.mae + rmse_ae = atomic_energy_metrics.rmse if out_put_spin: - mae_fr = mae(force_r - test_force_r) - mae_fm = mae(force_m - test_force_m) - rmse_fr = rmse(force_r - test_force_r) - rmse_fm = rmse(force_m - test_force_m) + if spin_metrics is None or spin_metrics.force_real is None: + raise RuntimeError("Spin force metrics are unavailable for dp test.") + mae_fr = spin_metrics.force_real.mae + rmse_fr = spin_metrics.force_real.rmse + if find_force_mag == 1: + if spin_metrics.force_magnetic is None: + raise RuntimeError("Spin magnetic force metrics are unavailable.") + mae_fm = spin_metrics.force_magnetic.mae + rmse_fm = spin_metrics.force_magnetic.rmse log.info(f"# number of test data : {numb_test:d} ") if find_energy == 1: @@ -505,146 +752,76 @@ def test_ener( log.info(f"Energy RMSE : {rmse_e:e} eV") log.info(f"Energy MAE/Natoms : {mae_ea:e} eV") log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV") - dict_to_return["mae_e"] = (mae_e, energy.size) - dict_to_return["mae_ea"] = (mae_ea, energy.size) - dict_to_return["rmse_e"] = (rmse_e, energy.size) - dict_to_return["rmse_ea"] = (rmse_ea, energy.size) if not out_put_spin and find_force == 1: log.info(f"Force MAE : {mae_f:e} eV/Å") log.info(f"Force RMSE : {rmse_f:e} eV/Å") - dict_to_return["mae_f"] = (mae_f, size_f) - dict_to_return["rmse_f"] = (rmse_f, size_f) - if find_atom_pref == 1: + if weighted_force_metrics is not None: log.info(f"Force weighted MAE : {mae_fw:e} eV/Å") log.info(f"Force weighted RMSE: {rmse_fw:e} eV/Å") - dict_to_return["mae_fw"] = (mae_fw, weight_sum) - dict_to_return["rmse_fw"] = (rmse_fw, weight_sum) + dict_to_return.update( + 
weighted_force_metrics.as_weighted_average_errors( + *DP_TEST_WEIGHTED_FORCE_METRIC_KEYS + ) + ) if out_put_spin and find_force == 1: log.info(f"Force atom MAE : {mae_fr:e} eV/Å") log.info(f"Force atom RMSE : {rmse_fr:e} eV/Å") - dict_to_return["mae_fr"] = (mae_fr, force_r.size) - dict_to_return["rmse_fr"] = (rmse_fr, force_r.size) + dict_to_return.update( + spin_metrics.as_weighted_average_errors( + {"force_real": DP_TEST_SPIN_WEIGHTED_METRIC_KEYS["force_real"]} + ) + ) if out_put_spin and find_force_mag == 1: log.info(f"Force spin MAE : {mae_fm:e} eV/uB") log.info(f"Force spin RMSE : {rmse_fm:e} eV/uB") - dict_to_return["mae_fm"] = (mae_fm, force_m.size) - dict_to_return["rmse_fm"] = (rmse_fm, force_m.size) + dict_to_return.update( + spin_metrics.as_weighted_average_errors( + {"force_magnetic": DP_TEST_SPIN_WEIGHTED_METRIC_KEYS["force_magnetic"]} + ) + ) if data.pbc and not out_put_spin and find_virial == 1: log.info(f"Virial MAE : {mae_v:e} eV") log.info(f"Virial RMSE : {rmse_v:e} eV") log.info(f"Virial MAE/Natoms : {mae_va:e} eV") log.info(f"Virial RMSE/Natoms : {rmse_va:e} eV") - dict_to_return["mae_v"] = (mae_v, virial.size) - dict_to_return["mae_va"] = (mae_va, virial.size) - dict_to_return["rmse_v"] = (rmse_v, virial.size) - dict_to_return["rmse_va"] = (rmse_va, virial.size) if has_atom_ener: log.info(f"Atomic ener MAE : {mae_ae:e} eV") log.info(f"Atomic ener RMSE : {rmse_ae:e} eV") if dp.has_hessian: log.info(f"Hessian MAE : {mae_h:e} eV/Å^2") log.info(f"Hessian RMSE : {rmse_h:e} eV/Å^2") - dict_to_return["mae_h"] = (mae_h, hessian.size) - dict_to_return["rmse_h"] = (rmse_h, hessian.size) + if hessian_metrics is None: + raise RuntimeError("Hessian metrics are unavailable for dp test.") + dict_to_return.update( + hessian_metrics.as_weighted_average_errors(*DP_TEST_HESSIAN_METRIC_KEYS) + ) if detail_file is not None: - detail_path = Path(detail_file) - - pe = np.concatenate( - ( - np.reshape(test_data["energy"][:numb_test], [-1, 1]), - np.reshape(energy, 
[-1, 1]), - ), - axis=1, - ) - save_txt_file( - detail_path.with_suffix(".e.out"), - pe, - header=f"{system}: data_e pred_e", - append=append_detail, - ) - pe_atom = pe / natoms - save_txt_file( - detail_path.with_suffix(".e_peratom.out"), - pe_atom, - header=f"{system}: data_e pred_e", - append=append_detail, + _write_energy_test_details( + detail_path=Path(detail_file), + system=system, + natoms=natoms, + append_detail=append_detail, + reference_energy=test_data["energy"][:numb_test], + prediction_energy=energy, + reference_force=test_data["force"][:numb_test], + prediction_force=force, + reference_virial=test_data["virial"][:numb_test] + if find_virial == 1 and data.pbc + else None, + prediction_virial=virial if find_virial == 1 and data.pbc else None, + out_put_spin=out_put_spin, + reference_force_real=test_force_r, + prediction_force_real=force_r, + reference_force_magnetic=test_force_m if find_force_mag == 1 else None, + prediction_force_magnetic=force_m + if out_put_spin and find_force_mag == 1 + else None, + reference_hessian=test_data["hessian"][:numb_test] + if dp.has_hessian + else None, + prediction_hessian=hessian if dp.has_hessian else None, ) - if not out_put_spin: - pf = np.concatenate( - ( - np.reshape(test_data["force"][:numb_test], [-1, 3]), - np.reshape(force, [-1, 3]), - ), - axis=1, - ) - save_txt_file( - detail_path.with_suffix(".f.out"), - pf, - header=f"{system}: data_fx data_fy data_fz pred_fx pred_fy pred_fz", - append=append_detail, - ) - else: - pf_real = np.concatenate( - (np.reshape(test_force_r, [-1, 3]), np.reshape(force_r, [-1, 3])), - axis=1, - ) - pf_mag = np.concatenate( - (np.reshape(test_force_m, [-1, 3]), np.reshape(force_m, [-1, 3])), - axis=1, - ) - save_txt_file( - detail_path.with_suffix(".fr.out"), - pf_real, - header=f"{system}: data_fx data_fy data_fz pred_fx pred_fy pred_fz", - append=append_detail, - ) - save_txt_file( - detail_path.with_suffix(".fm.out"), - pf_mag, - header=f"{system}: data_fmx data_fmy data_fmz 
pred_fmx pred_fmy pred_fmz", - append=append_detail, - ) - pv = np.concatenate( - ( - np.reshape(test_data["virial"][:numb_test], [-1, 9]), - np.reshape(virial, [-1, 9]), - ), - axis=1, - ) - save_txt_file( - detail_path.with_suffix(".v.out"), - pv, - header=f"{system}: data_vxx data_vxy data_vxz data_vyx data_vyy " - "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " - "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", - append=append_detail, - ) - pv_atom = pv / natoms - save_txt_file( - detail_path.with_suffix(".v_peratom.out"), - pv_atom, - header=f"{system}: data_vxx data_vxy data_vxz data_vyx data_vyy " - "data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx " - "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz", - append=append_detail, - ) - if dp.has_hessian: - data_h = test_data["hessian"][:numb_test].reshape(-1, 1) - pred_h = hessian.reshape(-1, 1) - h = np.concatenate( - ( - data_h, - pred_h, - ), - axis=1, - ) - save_txt_file( - detail_path.with_suffix(".h.out"), - h, - header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)", - append=append_detail, - ) return dict_to_return @@ -669,9 +846,10 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None: log.info(f"Force weighted RMSE: {avg['rmse_fw']:e} eV/Å") else: log.info(f"Force atom MAE : {avg['mae_fr']:e} eV/Å") - log.info(f"Force spin MAE : {avg['mae_fm']:e} eV/uB") log.info(f"Force atom RMSE : {avg['rmse_fr']:e} eV/Å") - log.info(f"Force spin RMSE : {avg['rmse_fm']:e} eV/uB") + if "rmse_fm" in avg: + log.info(f"Force spin MAE : {avg['mae_fm']:e} eV/uB") + log.info(f"Force spin RMSE : {avg['rmse_fm']:e} eV/uB") if "rmse_v" in avg: log.info(f"Virial MAE : {avg['mae_v']:e} eV") log.info(f"Virial RMSE : {avg['rmse_v']:e} eV") diff --git a/deepmd/jax/utils/auto_batch_size.py b/deepmd/jax/utils/auto_batch_size.py index 1ecf020086..df526a732b 100644 --- a/deepmd/jax/utils/auto_batch_size.py +++ b/deepmd/jax/utils/auto_batch_size.py @@ -24,10 +24,13 @@ def 
__init__( self, initial_batch_size: int = 1024, factor: float = 2.0, + *, + silent: bool = False, ) -> None: super().__init__( initial_batch_size=initial_batch_size, factor=factor, + silent=silent, ) def is_gpu_available(self) -> bool: diff --git a/deepmd/pd/utils/auto_batch_size.py b/deepmd/pd/utils/auto_batch_size.py index f45746ed95..de557df3b6 100644 --- a/deepmd/pd/utils/auto_batch_size.py +++ b/deepmd/pd/utils/auto_batch_size.py @@ -22,10 +22,13 @@ def __init__( self, initial_batch_size: int = 1024, factor: float = 2.0, + *, + silent: bool = False, ) -> None: super().__init__( initial_batch_size=initial_batch_size, factor=factor, + silent=silent, ) def is_gpu_available(self) -> bool: diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 46ad8a6cd0..70c90070fa 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -108,6 +108,7 @@ def get_trainer( finetune_links: dict[str, Any] | None = None, ) -> training.Trainer: multi_task = "model_dict" in config.get("model", {}) + config = normalize(config, multi_task=multi_task, check=False) def prepare_trainer_input_single( model_params_single: dict[str, Any], diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8d16e1c7ea..7070c132f9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -54,6 +54,10 @@ KFOptimizerWrapper, LKFOptimizer, ) +from deepmd.pt.train.validation import ( + FullValidator, + resolve_full_validation_start_step, +) from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -857,11 +861,89 @@ def single_model_finetune( self.enable_profiler = training_params.get("enable_profiler", False) self.profiling = training_params.get("profiling", False) self.profiling_file = training_params.get("profiling_file", "timeline.json") + validating_params = config.get("validating") or {} + self.full_validator = self._create_full_validator( + validating_params=validating_params, + validation_data=validation_data, + ) # 
Log model parameter count if self.rank == 0: self._log_parameter_count() + def _create_full_validator( + self, + *, + validating_params: dict[str, Any], + validation_data: DpLoaderSet | None, + ) -> FullValidator | None: + """Create the runtime full validator when it is active.""" + if not self._is_full_validation_requested(validating_params): + return None + self._raise_if_full_validation_unsupported(validation_data) + if validation_data is None: + raise RuntimeError( + "validation_data must be available after full validation checks." + ) + return FullValidator( + validating_params=validating_params, + validation_data=validation_data, + model=self.model, + train_infos=self._get_inner_module().train_infos, + num_steps=self.num_steps, + rank=self.rank, + zero_stage=self.zero_stage, + restart_training=self.restart_training, + checkpoint_dir=Path(self.save_ckpt).parent, + ) + + def _is_full_validation_requested(self, validating_params: dict[str, Any]) -> bool: + """Check whether full validation can trigger during this training run.""" + if not validating_params.get("full_validation", False): + return False + start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.5), + self.num_steps, + ) + return start_step is not None and start_step <= self.num_steps + + def _raise_if_full_validation_unsupported( + self, + validation_data: DpLoaderSet | None, + ) -> None: + """Validate runtime full validation constraints.""" + if self.multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + + has_spin = getattr(self.model, "has_spin", False) + if callable(has_spin): + has_spin = has_spin() + if has_spin or isinstance(self.loss, EnergySpinLoss): + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; spin-energy training is not supported." 
+ ) + + if not isinstance(self.loss, EnergyStdLoss): + raise ValueError( + "validating.full_validation only supports single-task energy training." + ) + + if validation_data is None: + raise ValueError( + "validating.full_validation requires `training.validation_data` " + "to be configured." + ) + + if self.zero_stage >= 2: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training with training.zero_stage < 2." + ) + @staticmethod def _count_parameters(model: torch.nn.Module) -> tuple[int, int]: """ @@ -1363,6 +1445,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict: fout, display_step_id, cur_lr, train_results, valid_results ) + if self.full_validator is not None: + self.full_validator.run( + step_id=_step_id, + display_step=display_step_id, + lr=cur_lr, + save_checkpoint=self.save_model, + ) + if ( ( (display_step_id) % self.save_freq == 0 diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py new file mode 100644 index 0000000000..c38e3b584a --- /dev/null +++ b/deepmd/pt/train/validation.py @@ -0,0 +1,713 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later + +from __future__ import ( + annotations, +) + +import logging +import re +import traceback +from dataclasses import ( + dataclass, +) +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np +import torch +import torch.distributed as dist + +from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.argcheck import ( + normalize_full_validation_metric, + resolve_full_validation_start_step, +) +from deepmd.utils.eval_metrics import ( + 
FULL_VALIDATION_METRIC_FAMILY_BY_KEY, + FULL_VALIDATION_METRIC_KEY_MAP, + FULL_VALIDATION_WEIGHTED_METRIC_KEYS, + compute_energy_type_metrics, +) +from deepmd.utils.weight_avg import ( + weighted_average, +) + +log = logging.getLogger(__name__) + +if TYPE_CHECKING: + from deepmd.utils.data import ( + DeepmdData, + ) + +LOG_COLUMN_ORDER = [ + ("E_MAE", "mae_e_per_atom"), + ("E_RMSE", "rmse_e_per_atom"), + ("F_MAE", "mae_f"), + ("F_RMSE", "rmse_f"), + ("V_MAE", "mae_v_per_atom"), + ("V_RMSE", "rmse_v_per_atom"), +] + +TOPK_RECORDS_INFO_KEY = "full_validation_topk_records" +BEST_METRIC_NAME_INFO_KEY = "full_validation_metric" +BEST_CKPT_GLOB = "best.ckpt-*.t-*.pt" +BEST_CKPT_PATTERN = re.compile(r"^best\.ckpt-(\d+)\.t-(\d+)\.pt$") +STALE_FULL_VALIDATION_INFO_KEYS = ( + "full_validation_best_metric", + "full_validation_best_step", + "full_validation_best_path", + "full_validation_best_records", +) +VAL_LOG_SIGNIFICANT_DIGITS = 5 +VAL_LOG_COLUMN_GAP = " " +VAL_LOG_HEADER_PREFIX = "# " +VAL_LOG_DATA_PREFIX = " " +METRIC_LOG_UNIT_MAP = { + "e": ("meV/atom", 1000.0), + "f": ("meV/Å", 1000.0), + "v": ("meV/atom", 1000.0), +} + + +@dataclass(frozen=True) +class FullValidationResult: + """Result of one full validation run.""" + + display_step: int + metrics: dict[str, float] + selected_metric_key: str + selected_metric_value: float + saved_best_path: str | None + + +@dataclass(order=True, frozen=True) +class BestCheckpointRecord: + """One best-checkpoint record ordered by metric then step.""" + + metric: float + step: int + + +def parse_validation_metric(metric: str) -> tuple[str, str]: + """Parse the configured full validation metric.""" + normalized_metric = normalize_full_validation_metric(metric) + if normalized_metric not in FULL_VALIDATION_METRIC_KEY_MAP: + supported_metrics = ", ".join( + item.upper() for item in FULL_VALIDATION_METRIC_KEY_MAP + ) + raise ValueError( + "validating.validation_metric must be one of " + f"{supported_metrics}, got {metric!r}." 
+ ) + return normalized_metric, FULL_VALIDATION_METRIC_KEY_MAP[normalized_metric] + + +def format_metric_for_log( + metric_name: str, metric_value: float +) -> tuple[str, float, str]: + """Format a full validation metric for user-facing logging.""" + metric_family, metric_kind = metric_name.split(":") + metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + metric_label = f"{metric_family.upper()}:{metric_kind.upper()}" + return metric_label, metric_value * metric_scale, metric_unit + + +def format_metric_value_for_table( + metric_key: str, metric_value: float +) -> tuple[float, str]: + """Format one table metric value and its unit for `val.log`.""" + metric_family = FULL_VALIDATION_METRIC_FAMILY_BY_KEY.get(metric_key) + if metric_family is None: + raise ValueError(f"Unknown full validation metric key: {metric_key}") + metric_unit, metric_scale = METRIC_LOG_UNIT_MAP[metric_family] + return metric_value * metric_scale, metric_unit + + +def format_metric_number_for_log(metric_value: float) -> str: + """Format one metric value for `val.log` and best-save messages.""" + if np.isnan(metric_value): + return "nan" + if np.isposinf(metric_value): + return "inf" + if np.isneginf(metric_value): + return "-inf" + if metric_value == 0.0: + return "0" + abs_value = abs(metric_value) + if abs_value < np.finfo(float).tiny: + return "0" + decimals = VAL_LOG_SIGNIFICANT_DIGITS - int(np.floor(np.log10(abs_value))) - 1 + if decimals > 16: + return f"{metric_value:.{VAL_LOG_SIGNIFICANT_DIGITS - 1}e}" + rounded_value = round(metric_value, decimals) + if rounded_value == 0.0: + rounded_value = 0.0 + if decimals > 0: + return f"{rounded_value:.{decimals}f}" + return f"{rounded_value:.0f}" + + +class FullValidator: + """Run independent full validation during training.""" + + def __init__( + self, + *, + validating_params: dict[str, Any], + validation_data: Any, + model: torch.nn.Module, + train_infos: dict[str, Any], + num_steps: int, + rank: int, + zero_stage: int, + 
restart_training: bool, + checkpoint_dir: Path | None = None, + ) -> None: + self.validation_data = validation_data + self.model = model + self.train_infos = train_infos + self.rank = rank + self.zero_stage = zero_stage + self.checkpoint_dir = ( + Path(checkpoint_dir) if checkpoint_dir is not None else Path(".") + ) + self.is_distributed = dist.is_available() and dist.is_initialized() + + self.full_validation = bool(validating_params.get("full_validation", False)) + self.validation_freq = int(validating_params.get("validation_freq", 5000)) + self.save_best = bool(validating_params.get("save_best", True)) + self.max_best_ckpt = int(validating_params.get("max_best_ckpt", 1)) + self.metric_name, self.metric_key = parse_validation_metric( + str(validating_params.get("validation_metric", "E:MAE")) + ) + self.full_val_file = Path(validating_params.get("full_val_file", "val.log")) + self.start_step = resolve_full_validation_start_step( + validating_params.get("full_val_start", 0.5), + num_steps, + ) + self.enabled = ( + self.full_validation + and self.start_step is not None + and self.start_step <= num_steps + ) + self.step_column_width = max(len("step"), len(str(num_steps))) + self._write_mode = "a" if restart_training else "w" + self._should_write_header = not ( + restart_training and self.full_val_file.exists() + ) + self.auto_batch_size = AutoBatchSize(silent=True) + self.table_column_specs = [] + for column_name, metric_key in LOG_COLUMN_ORDER: + _, metric_unit = format_metric_value_for_table(metric_key, 1.0) + header_label = f"{column_name}({metric_unit})" + self.table_column_specs.append( + (metric_key, header_label, max(len(header_label), 18)) + ) + + self.topk_records = self._load_topk_records() + self._sync_train_infos() + if self.rank == 0: + self._initialize_best_checkpoints(restart_training=restart_training) + + def should_run(self, display_step: int) -> bool: + """Check whether the current step should trigger full validation.""" + if not self.enabled or 
self.start_step is None: + return False + if display_step < self.start_step: + return False + return (display_step - self.start_step) % self.validation_freq == 0 + + def run( + self, + *, + step_id: int, + display_step: int, + lr: float, + save_checkpoint: Any, + ) -> FullValidationResult | None: + """Run full validation if the current step is due.""" + if not self.should_run(display_step): + return None + + if self.is_distributed: + dist.barrier() + + result: FullValidationResult | None = None + caught_exception: Exception | None = None + error_message = None + save_path = [None] + if self.rank == 0: + try: + result = self._evaluate(display_step) + save_path[0] = result.saved_best_path + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed on rank 0 during evaluation:\n" + f"{traceback.format_exc()}" + ) + + self._raise_if_distributed_error(error_message, caught_exception) + + if self.is_distributed: + dist.broadcast_object_list(save_path, src=0) + + if save_path[0] is not None: + try: + # ZeRO/FSDP checkpoint collection is collective, so all ranks must + # enter `save_checkpoint` whenever `zero_stage > 0`. 
+ if (self.is_distributed and self.zero_stage != 0) or self.rank == 0: + save_checkpoint(Path(save_path[0]), lr=lr, step=step_id) + if self.rank == 0: + self._reconcile_best_checkpoints() + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed while saving the best checkpoint:\n" + f"{traceback.format_exc()}" + ) + else: + error_message = None + caught_exception = None + + self._raise_if_distributed_error(error_message, caught_exception) + + if self.rank == 0: + try: + self._log_result(result) + except Exception as exc: + caught_exception = exc + error_message = ( + "Full validation failed while writing logs:\n" + f"{traceback.format_exc()}" + ) + else: + error_message = None + caught_exception = None + + self._raise_if_distributed_error(error_message, caught_exception) + + if self.is_distributed: + dist.barrier() + + return result if self.rank == 0 else None + + def _evaluate(self, display_step: int) -> FullValidationResult: + """Evaluate all validation systems and update best state.""" + # === Step 1. Switch to Evaluation Mode === + was_training = bool(getattr(self.model, "training", True)) + self.model.eval() + try: + # === Step 2. Evaluate All Systems === + metrics = self.evaluate_all_systems() + finally: + self.model.train(was_training) + + if self.metric_key not in metrics or np.isnan(metrics[self.metric_key]): + raise RuntimeError( + "The selected full validation metric is unavailable on the " + f"validation dataset: {self.metric_name.upper()}." + ) + + # === Step 3. 
Update Best Tracking === + selected_metric_value = float(metrics[self.metric_key]) + saved_best_path = self._update_best_state( + display_step=display_step, + selected_metric_value=selected_metric_value, + ) + return FullValidationResult( + display_step=display_step, + metrics=metrics, + selected_metric_key=self.metric_key, + selected_metric_value=selected_metric_value, + saved_best_path=saved_best_path, + ) + + def evaluate_all_systems(self) -> dict[str, float]: + """Evaluate every validation system and aggregate metrics.""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + system_metrics = [] + for dataset in self.validation_data.systems: + if not isinstance(dataset, DeepmdDataSetForLoader): + raise TypeError( + "Full validation expects each dataset in validation_data.systems " + f"to be DeepmdDataSetForLoader, got {type(dataset)!r}." + ) + system_metrics.append(self._evaluate_system(dataset.data_system)) + + aggregated = weighted_average([metric for metric in system_metrics if metric]) + return { + metric_key: float(aggregated[metric_key]) + for _, metric_key in LOG_COLUMN_ORDER + if metric_key in aggregated + } + + def _evaluate_system( + self, data_system: DeepmdData + ) -> dict[str, tuple[float, float]]: + """Evaluate one validation system.""" + test_data = data_system.get_test() + natoms = int(test_data["type"].shape[1]) + nframes = int(test_data["coord"].shape[0]) + include_virial = data_system.pbc and bool(test_data.get("find_virial", 0.0)) + prediction = self._predict_outputs( + coord=test_data["coord"].reshape(nframes, -1), + atom_types=test_data["type"], + box=test_data["box"] if data_system.pbc else None, + fparam=test_data["fparam"] + if bool(test_data.get("find_fparam", 0.0)) + else None, + aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, + include_virial=include_virial, + natoms=natoms, + nframes=nframes, + ) + shared_metrics = compute_energy_type_metrics( + prediction=prediction, + test_data=test_data, + 
natoms=natoms, + has_pbc=data_system.pbc, + ) + return shared_metrics.as_weighted_average_errors( + FULL_VALIDATION_WEIGHTED_METRIC_KEYS + ) + + def _predict_outputs( + self, + *, + coord: np.ndarray, + atom_types: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + include_virial: bool, + natoms: int, + nframes: int, + ) -> dict[str, np.ndarray]: + """Predict energy, force, and virial for the full validation batch.""" + + def predict_batch( + coord_batch: np.ndarray, + atom_types_batch: np.ndarray, + box_batch: np.ndarray | None, + fparam_batch: np.ndarray | None, + aparam_batch: np.ndarray | None, + ) -> dict[str, np.ndarray]: + coord_input = torch.tensor( + coord_batch.reshape(-1, natoms, 3).astype( + NP_PRECISION_DICT[ + RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION] + ] + ), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + type_input = torch.tensor( + atom_types_batch.astype(np.int64), + dtype=torch.long, + device=DEVICE, + ) + if box_batch is not None: + box_input = torch.tensor( + box_batch.reshape(-1, 3, 3).astype( + NP_PRECISION_DICT[ + RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION] + ] + ), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + else: + box_input = None + if fparam_batch is not None: + fparam_input = to_torch_tensor( + fparam_batch.reshape(-1, self.model.get_dim_fparam()) + ) + else: + fparam_input = None + if aparam_batch is not None: + aparam_input = to_torch_tensor( + aparam_batch.reshape(-1, natoms, self.model.get_dim_aparam()) + ) + else: + aparam_input = None + + # Do not use `torch.no_grad()` here: force/virial predictions rely on + # autograd inside the model even during evaluation. 
+ batch_output = self.model( + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + ) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + + prediction = { + "energy": batch_output["energy"].detach().cpu().numpy().reshape(-1, 1), + "force": batch_output["force"] + .detach() + .cpu() + .numpy() + .reshape(-1, natoms * 3), + } + if include_virial: + if "virial" not in batch_output: + raise KeyError( + "Full validation requested virial metrics, but model " + "output does not contain `virial`." + ) + prediction["virial"] = ( + batch_output["virial"].detach().cpu().numpy().reshape(-1, 9) + ) + return prediction + + batch_prediction = self.auto_batch_size.execute_all( + predict_batch, + nframes, + natoms, + coord, + atom_types, + box, + fparam, + aparam, + ) + prediction = { + "energy": batch_prediction["energy"], + "force": batch_prediction["force"], + } + if include_virial: + prediction["virial"] = batch_prediction["virial"] + return prediction + + def _update_best_state( + self, + *, + display_step: int, + selected_metric_value: float, + ) -> str | None: + """Update the top-K records and return the checkpoint path to save.""" + candidate = BestCheckpointRecord( + metric=selected_metric_value, + step=display_step, + ) + updated_records = [ + record for record in self.topk_records if record.step != display_step + ] + updated_records.append(candidate) + updated_records.sort() + updated_records = updated_records[: self.max_best_ckpt] + if candidate not in updated_records: + return None + + self.topk_records = updated_records + self._sync_train_infos() + if not self.save_best: + return None + candidate_rank = self.topk_records.index(candidate) + 1 + return str(self._best_checkpoint_path(display_step, candidate_rank)) + + def _sync_train_infos(self) -> None: + """Synchronize top-K validation state into train infos.""" + for key in STALE_FULL_VALIDATION_INFO_KEYS: + self.train_infos.pop(key, None) + 
self.train_infos[BEST_METRIC_NAME_INFO_KEY] = self.metric_name + self.train_infos[TOPK_RECORDS_INFO_KEY] = [ + {"metric": record.metric, "step": record.step} + for record in self.topk_records + ] + + def _load_topk_records(self) -> list[BestCheckpointRecord]: + """Load top-K records from train infos for the current metric.""" + if self.train_infos.get(BEST_METRIC_NAME_INFO_KEY) != self.metric_name: + return [] + raw_records = self.train_infos.get(TOPK_RECORDS_INFO_KEY, []) + if not isinstance(raw_records, list): + return [] + records = [] + for raw_record in raw_records: + if not isinstance(raw_record, dict): + continue + if "metric" not in raw_record or "step" not in raw_record: + continue + records.append( + BestCheckpointRecord( + metric=float(raw_record["metric"]), + step=int(raw_record["step"]), + ) + ) + records.sort() + return records[: self.max_best_ckpt] + + def _best_checkpoint_name(self, step: int, rank: int) -> str: + """Build the best-checkpoint filename for one step.""" + return f"best.ckpt-{step}.t-{rank}.pt" + + def _best_checkpoint_path(self, step: int, rank: int) -> Path: + """Build the best-checkpoint path for one step.""" + return self.checkpoint_dir / self._best_checkpoint_name(step, rank) + + def _list_best_checkpoints(self) -> list[Path]: + """List all managed best checkpoints in the checkpoint directory.""" + best_checkpoints = [ + path + for path in self.checkpoint_dir.glob(BEST_CKPT_GLOB) + if path.is_file() and not path.is_symlink() + ] + best_checkpoints.sort(key=lambda path: path.stat().st_mtime) + return best_checkpoints + + def _expected_topk_checkpoint_names(self) -> dict[int, str]: + """Return the expected checkpoint filename for each retained step.""" + return { + record.step: self._best_checkpoint_name(record.step, rank) + for rank, record in enumerate(self.topk_records, start=1) + } + + def _reconcile_best_checkpoints(self) -> None: + """Rename retained best checkpoints to ranked names and delete stale ones.""" + expected_names = 
self._expected_topk_checkpoint_names() + current_files = self._list_best_checkpoints() + files_by_step: dict[int, list[Path]] = {} + stale_files: list[Path] = [] + for checkpoint_path in current_files: + match = BEST_CKPT_PATTERN.match(checkpoint_path.name) + if match is None: + stale_files.append(checkpoint_path) + continue + step = int(match.group(1)) + files_by_step.setdefault(step, []).append(checkpoint_path) + + temp_moves: list[tuple[Path, Path]] = [] + for step, checkpoint_paths in files_by_step.items(): + expected_name = expected_names.get(step) + if expected_name is None: + stale_files.extend(checkpoint_paths) + continue + + keep_path = next( + ( + checkpoint_path + for checkpoint_path in checkpoint_paths + if checkpoint_path.name == expected_name + ), + checkpoint_paths[0], + ) + for checkpoint_path in checkpoint_paths: + if checkpoint_path != keep_path: + stale_files.append(checkpoint_path) + if keep_path.name != expected_name: + temp_path = keep_path.with_name(f"{keep_path.name}.tmp") + keep_path.rename(temp_path) + temp_moves.append((temp_path, keep_path.with_name(expected_name))) + + for checkpoint_path in stale_files: + checkpoint_path.unlink(missing_ok=True) + for temp_path, final_path in temp_moves: + final_path.unlink(missing_ok=True) + temp_path.rename(final_path) + + def _initialize_best_checkpoints(self, restart_training: bool) -> None: + """Align on-disk best checkpoints with the current training mode.""" + if restart_training and self.save_best and self.topk_records: + self._reconcile_best_checkpoints() + return + for checkpoint_path in self._list_best_checkpoints(): + checkpoint_path.unlink(missing_ok=True) + + def _raise_if_distributed_error( + self, + local_error_message: str | None, + local_exception: Exception | None = None, + ) -> None: + """Propagate a local error to all ranks and raise consistently.""" + error_message = local_error_message + if self.is_distributed: + gathered_errors = [None] * dist.get_world_size() + 
dist.all_gather_object(gathered_errors, local_error_message) + error_message = next( + (message for message in gathered_errors if message is not None), None + ) + if error_message is None: + return + if local_exception is not None: + raise RuntimeError(error_message) from local_exception + raise RuntimeError(error_message) + + def _log_result(self, result: FullValidationResult | None) -> None: + """Log and persist full validation results on rank 0.""" + if result is None: + raise ValueError("Full validation logging requires a result on rank 0.") + self._write_log_file(result) + if result.saved_best_path is not None: + metric_label, metric_value, metric_unit = format_metric_for_log( + self.metric_name, result.selected_metric_value + ) + log.info( + f"Saved best model to {result.saved_best_path} " + f"with {metric_label} = {format_metric_number_for_log(metric_value)} " + f"{metric_unit}" + ) + + def _write_log_file(self, result: FullValidationResult) -> None: + """Append one full validation entry to the dedicated log file.""" + with self.full_val_file.open(self._write_mode, buffering=1) as fout: + if self._should_write_header: + header = VAL_LOG_HEADER_PREFIX + f"{'step':^{self.step_column_width}s}" + for _, header_label, column_width in self.table_column_specs: + header += VAL_LOG_COLUMN_GAP + f"{header_label:^{column_width}s}" + header += "\n" + header += ( + "# E uses per-atom energy, F uses component-wise force errors, " + "and V uses virial normalized by natoms.\n" + ) + fout.write(header) + self._should_write_header = False + self._write_mode = "a" + + line = ( + VAL_LOG_DATA_PREFIX + + f"{result.display_step:^{self.step_column_width}d}" + ) + for metric_key, _, column_width in self.table_column_specs: + metric_value = result.metrics.get(metric_key, float("nan")) + if not np.isnan(metric_value): + metric_value, _ = format_metric_value_for_table( + metric_key, metric_value + ) + metric_text = format_metric_number_for_log(metric_value) + line += 
VAL_LOG_COLUMN_GAP + f"{metric_text:^{column_width}s}" + line += "\n" + fout.write(line) + if result.saved_best_path is not None: + metric_label, metric_value, metric_unit = format_metric_for_log( + self.metric_name, result.selected_metric_value + ) + fout.write( + "# saved best checkpoint: " + f"{result.saved_best_path} ({metric_label} = " + f"{format_metric_number_for_log(metric_value)} {metric_unit})\n" + ) diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py index 368f7808ab..5f8e0930d3 100644 --- a/deepmd/pt/utils/auto_batch_size.py +++ b/deepmd/pt/utils/auto_batch_size.py @@ -22,10 +22,13 @@ def __init__( self, initial_batch_size: int = 1024, factor: float = 2.0, + *, + silent: bool = False, ) -> None: super().__init__( initial_batch_size=initial_batch_size, factor=factor, + silent=silent, ) def is_gpu_available(self) -> bool: diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index ce9a6c52c6..20a76a0e87 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -49,6 +49,11 @@ def __init__( def __len__(self) -> int: return self._data_system.nframes + @property + def data_system(self) -> DeepmdData: + """Expose the underlying DeePMD data system.""" + return self._data_system + def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" b_data = self._data_system.get_item_torch(index, max(1, NUM_WORKERS)) diff --git a/deepmd/tf/utils/batch_size.py b/deepmd/tf/utils/batch_size.py index 438bf36703..13f69c84dc 100644 --- a/deepmd/tf/utils/batch_size.py +++ b/deepmd/tf/utils/batch_size.py @@ -19,10 +19,16 @@ class AutoBatchSize(AutoBatchSizeBase): - def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: - super().__init__(initial_batch_size, factor) + def __init__( + self, + initial_batch_size: int = 1024, + factor: float = 2.0, + *, + silent: bool = False, + ) -> None: + super().__init__(initial_batch_size, factor, silent=silent) 
DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) - if not DP_INFER_BATCH_SIZE > 0: + if not DP_INFER_BATCH_SIZE > 0 and not self.silent: if self.is_gpu_available(): log.info( "If you encounter the error 'an illegal memory access was encountered', this may be due to a TensorFlow issue. " diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5fe1d4f3f1..ddb3f290df 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -4144,6 +4144,222 @@ def training_extra_check(data: dict | None) -> bool: ) +FULL_VALIDATION_METRIC_PREFS = { + "e:mae": ("start_pref_e", "limit_pref_e"), + "e:rmse": ("start_pref_e", "limit_pref_e"), + "f:mae": ("start_pref_f", "limit_pref_f"), + "f:rmse": ("start_pref_f", "limit_pref_f"), + "v:mae": ("start_pref_v", "limit_pref_v"), + "v:rmse": ("start_pref_v", "limit_pref_v"), +} + + +def normalize_full_validation_metric(metric: str) -> str: + """Normalize the full validation metric string.""" + return metric.strip().lower() + + +def is_valid_full_validation_metric(metric: str) -> bool: + """Check whether a full validation metric is supported.""" + return normalize_full_validation_metric(metric) in FULL_VALIDATION_METRIC_PREFS + + +def get_full_validation_metric_prefactors(metric: str) -> tuple[str, str]: + """Get the prefactor keys required by a full validation metric.""" + normalized_metric = normalize_full_validation_metric(metric) + if normalized_metric not in FULL_VALIDATION_METRIC_PREFS: + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + raise ValueError( + f"validating.validation_metric must be one of {valid_metrics}, got {metric!r}." 
+ ) + return FULL_VALIDATION_METRIC_PREFS[normalized_metric] + + +def resolve_full_validation_start_step( + full_val_start: float, num_steps: int +) -> int | None: + """Resolve the first step at which full validation becomes active.""" + start_value = float(full_val_start) + if start_value == 1.0: + return None + if 0.0 <= start_value < 1.0: + return int(num_steps * start_value) + return int(start_value) + + +def validating_args() -> Argument: + """Generate full validation arguments.""" + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + doc_full_validation = ( + "Whether to run an additional full validation pass over the entire " + "validation dataset during training. This flow is independent from the " + "display-time validation controlled by `training.disp_freq`. Only " + "single-task energy training is supported. Multi-task, spin-energy, " + "and `training.zero_stage >= 2` are not supported." + ) + doc_validation_freq = ( + "The frequency, in training steps, of running the full validation pass." + ) + doc_save_best = ( + "Whether to save an extra checkpoint when the selected full validation " + "metric reaches a new best value." + ) + doc_max_best_ckpt = ( + "The maximum number of top-ranked best checkpoints to keep. The best " + "checkpoints are ranked by the selected validation metric in ascending " + "order. Default is 1." + ) + doc_validation_metric = ( + "Metric used to determine the best checkpoint during full validation. " + f"Supported values are {valid_metrics}. The string is case-insensitive. " + "`E` and `V` are per-atom metrics; `F` uses component-wise force errors, " + "matching `dp test`. The corresponding loss prefactors must not both be 0." + ) + doc_full_val_file = ( + "The file for writing full validation results only. This file is " + "independent from `training.disp_file`." + ) + doc_full_val_start = ( + "The starting point of full validation. 
`0` means the feature is active " + "from the beginning and will trigger at every `validation_freq` steps. " + "A value in `(0, 1)` is interpreted as a ratio of `training.numb_steps`. " + "`1` disables the feature. A value larger than `1` is interpreted as the " + "starting step after integer conversion." + ) + args = [ + Argument( + "full_validation", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_full_validation, + ), + Argument( + "validation_freq", + int, + optional=True, + default=5000, + doc=doc_only_pt_supported + doc_validation_freq, + extra_check=lambda x: x > 0, + extra_check_errmsg="must be greater than 0", + ), + Argument( + "save_best", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_save_best, + ), + Argument( + "max_best_ckpt", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_max_best_ckpt, + extra_check=lambda x: x > 0, + extra_check_errmsg="must be greater than 0", + ), + Argument( + "validation_metric", + str, + optional=True, + default="E:MAE", + doc=doc_only_pt_supported + doc_validation_metric, + extra_check=is_valid_full_validation_metric, + extra_check_errmsg=( + "must be one of " + + ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + ), + ), + Argument( + "full_val_file", + str, + optional=True, + default="val.log", + doc=doc_only_pt_supported + doc_full_val_file, + ), + Argument( + "full_val_start", + [int, float], + optional=True, + default=0.5, + doc=doc_only_pt_supported + doc_full_val_start, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be greater than or equal to 0", + ), + ] + return Argument( + "validating", + dict, + sub_fields=args, + sub_variants=[], + optional=True, + default={}, + doc=doc_only_pt_supported + + "Independent full validation options for single-task energy training.", + ) + + +def validate_full_validation_config( + data: dict[str, Any], multi_task: bool = False +) -> None: + """Validate cross-section 
constraints for full validation.""" + validating = data.get("validating") or {} + training = data.get("training", {}) + if not validating.get("full_validation", False): + return + + metric = str(validating.get("validation_metric", "E:MAE")) + if not is_valid_full_validation_metric(metric): + valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) + raise ValueError( + "validating.validation_metric must be one of " + f"{valid_metrics}, got {metric!r}." + ) + + if multi_task: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; multi-task training is not supported." + ) + + loss_params = data.get("loss", {}) + loss_type = loss_params.get("type", "ener") + if loss_type == "ener_spin": + raise ValueError( + "validating.full_validation only supports single-task energy " + "training; spin-energy training is not supported." + ) + if loss_type != "ener": + raise ValueError( + "validating.full_validation only supports single-task energy " + f"training with loss.type='ener'; got loss.type={loss_type!r}." + ) + + if not training.get("validation_data"): + raise ValueError( + "validating.full_validation requires `training.validation_data`. " + "It is only supported for single-task energy training." + ) + + zero_stage = int(training.get("zero_stage", 0)) + if zero_stage >= 2: + raise ValueError( + "validating.full_validation only supports single-task energy " + "training with training.zero_stage < 2." + ) + + pref_start_key, pref_limit_key = get_full_validation_metric_prefactors(metric) + pref_start = float(loss_params.get(pref_start_key, 0.0)) + pref_limit = float(loss_params.get(pref_limit_key, 0.0)) + if pref_start == 0.0 or pref_limit == 0.0: + raise ValueError( + f"validating.validation_metric={metric!r} requires " + f"`loss.{pref_start_key}` and `loss.{pref_limit_key}` to both " + "be non-zero." 
+ ) + + def multi_model_args() -> list[Argument]: model_dict = model_args() model_dict.name = "model_dict" @@ -4218,6 +4434,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: optimizer_args(), loss_args(), training_args(multi_task=multi_task), + validating_args(), nvnmd_args(), ] else: @@ -4227,6 +4444,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: optimizer_args(fold_subdoc=True), multi_loss_args(), training_args(multi_task=multi_task), + validating_args(), nvnmd_args(fold_subdoc=True), ] @@ -4258,10 +4476,15 @@ def gen_json_schema(multi_task: bool = False) -> str: return json.dumps(generate_json_schema(arg)) -def normalize(data: dict[str, Any], multi_task: bool = False) -> dict[str, Any]: +def normalize( + data: dict[str, Any], multi_task: bool = False, *, check: bool = True +) -> dict[str, Any]: + """Normalize config values and optionally run strict schema checks.""" base = Argument("base", dict, gen_args(multi_task=multi_task)) data = base.normalize_value(data, trim_pattern="_*") - base.check_value(data, strict=True) + if check: + base.check_value(data, strict=True) + validate_full_validation_config(data, multi_task=multi_task) return data diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py index e701e82ec6..44c700947c 100644 --- a/deepmd/utils/batch_size.py +++ b/deepmd/utils/batch_size.py @@ -41,6 +41,8 @@ class AutoBatchSize(ABC): is not set factor : float, default: 2. 
increased factor + silent : bool, default: False + whether to suppress auto batch size informational logs Attributes ---------- @@ -52,8 +54,15 @@ class AutoBatchSize(ABC): minimal not working batch size """ - def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: + def __init__( + self, + initial_batch_size: int = 1024, + factor: float = 2.0, + *, + silent: bool = False, + ) -> None: # See also PyTorchLightning/pytorch-lightning#1638 + self.silent = silent self.current_batch_size = initial_batch_size DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) if DP_INFER_BATCH_SIZE > 0: @@ -68,11 +77,12 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: self.minimal_not_working_batch_size = ( self.maximum_working_batch_size + 1 ) - log.warning( - "You can use the environment variable DP_INFER_BATCH_SIZE to" - "control the inference batch size (nframes * natoms). " - f"The default value is {initial_batch_size}." - ) + if not self.silent: + log.warning( + "You can use the environment variable DP_INFER_BATCH_SIZE to" + "control the inference batch size (nframes * natoms). " + f"The default value is {initial_batch_size}." 
+ ) self.factor = factor @@ -143,9 +153,10 @@ def execute( def _adjust_batch_size(self, factor: float) -> None: old_batch_size = self.current_batch_size self.current_batch_size = int(self.current_batch_size * factor) - log.info( - f"Adjust batch size from {old_batch_size} to {self.current_batch_size}" - ) + if not self.silent: + log.info( + f"Adjust batch size from {old_batch_size} to {self.current_batch_size}" + ) def execute_all( self, diff --git a/deepmd/utils/eval_metrics.py b/deepmd/utils/eval_metrics.py new file mode 100644 index 0000000000..ed210c9b78 --- /dev/null +++ b/deepmd/utils/eval_metrics.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from __future__ import ( + annotations, +) + +from dataclasses import ( + dataclass, +) + +import numpy as np + +FULL_VALIDATION_METRIC_KEY_MAP = { + "e:mae": "mae_e_per_atom", + "e:rmse": "rmse_e_per_atom", + "f:mae": "mae_f", + "f:rmse": "rmse_f", + "v:mae": "mae_v_per_atom", + "v:rmse": "rmse_v_per_atom", +} +FULL_VALIDATION_WEIGHTED_METRIC_KEYS = { + "energy_per_atom": ("mae_e_per_atom", "rmse_e_per_atom"), + "force": ("mae_f", "rmse_f"), + "virial_per_atom": ("mae_v_per_atom", "rmse_v_per_atom"), +} +FULL_VALIDATION_METRIC_FAMILY_BY_KEY = { + "mae_e_per_atom": "e", + "rmse_e_per_atom": "e", + "mae_f": "f", + "rmse_f": "f", + "mae_v_per_atom": "v", + "rmse_v_per_atom": "v", +} +DP_TEST_WEIGHTED_METRIC_KEYS = { + "energy": ("mae_e", "rmse_e"), + "energy_per_atom": ("mae_ea", "rmse_ea"), + "force": ("mae_f", "rmse_f"), + "virial": ("mae_v", "rmse_v"), + "virial_per_atom": ("mae_va", "rmse_va"), +} +DP_TEST_SPIN_WEIGHTED_METRIC_KEYS = { + "force_real": ("mae_fr", "rmse_fr"), + "force_magnetic": ("mae_fm", "rmse_fm"), +} +DP_TEST_WEIGHTED_FORCE_METRIC_KEYS = ("mae_fw", "rmse_fw") +DP_TEST_HESSIAN_METRIC_KEYS = ("mae_h", "rmse_h") + + +def mae(diff: np.ndarray) -> float: + """Calculate mean absolute error.""" + return float(np.mean(np.abs(diff))) + + +def rmse(diff: np.ndarray) -> float: + 
"""Calculate root mean square error.""" + return float(np.sqrt(np.mean(diff * diff))) + + +@dataclass(frozen=True) +class ErrorStat: + """One weighted MAE/RMSE pair.""" + + mae: float + rmse: float + weight: float + + def as_weighted_average_errors( + self, + mae_key: str, + rmse_key: str, + ) -> dict[str, tuple[float, float]]: + """Convert one metric pair into `weighted_average` inputs.""" + return { + mae_key: (self.mae, self.weight), + rmse_key: (self.rmse, self.weight), + } + + +@dataclass(frozen=True) +class EnergyTypeEvalMetrics: + """Shared energy-type metrics for one evaluation batch or system.""" + + energy: ErrorStat | None = None + energy_per_atom: ErrorStat | None = None + force: ErrorStat | None = None + virial: ErrorStat | None = None + virial_per_atom: ErrorStat | None = None + + def as_weighted_average_errors( + self, + metric_keys: dict[str, tuple[str, str]], + ) -> dict[str, tuple[float, float]]: + """Project shared metrics into caller-specific error dict keys.""" + errors: dict[str, tuple[float, float]] = {} + for metric_name, (mae_key, rmse_key) in metric_keys.items(): + metric = getattr(self, metric_name) + if metric is not None: + errors.update(metric.as_weighted_average_errors(mae_key, rmse_key)) + return errors + + +@dataclass(frozen=True) +class SpinForceEvalMetrics: + """Shared spin-force metrics for one evaluation batch or system.""" + + force_real: ErrorStat | None = None + force_magnetic: ErrorStat | None = None + + def as_weighted_average_errors( + self, + metric_keys: dict[str, tuple[str, str]], + ) -> dict[str, tuple[float, float]]: + """Project shared spin metrics into caller-specific error dict keys.""" + errors: dict[str, tuple[float, float]] = {} + for metric_name, (mae_key, rmse_key) in metric_keys.items(): + metric = getattr(self, metric_name) + if metric is not None: + errors.update(metric.as_weighted_average_errors(mae_key, rmse_key)) + return errors + + +def compute_error_stat( + prediction: np.ndarray, + reference: 
np.ndarray, + *, + scale: float = 1.0, +) -> ErrorStat: + """Compute one MAE/RMSE pair from aligned prediction and reference arrays.""" + diff = prediction - reference + return ErrorStat( + mae=mae(diff) * scale, + rmse=rmse(diff) * scale, + weight=float(diff.size), + ) + + +def compute_weighted_error_stat( + prediction: np.ndarray, + reference: np.ndarray, + weight: np.ndarray, +) -> ErrorStat: + """Compute weighted MAE/RMSE from aligned prediction and reference arrays.""" + diff = prediction - reference + weight_sum = float(np.sum(weight)) + if weight_sum <= 0.0: + return ErrorStat(mae=0.0, rmse=0.0, weight=weight_sum) + return ErrorStat( + mae=float(np.sum(np.abs(diff) * weight) / weight_sum), + rmse=float(np.sqrt(np.sum(diff * diff * weight) / weight_sum)), + weight=weight_sum, + ) + + +def compute_energy_type_metrics( + prediction: dict[str, np.ndarray], + test_data: dict[str, np.ndarray], + natoms: int, + has_pbc: bool, +) -> EnergyTypeEvalMetrics: + """Compute shared energy-type metrics for one evaluation dataset.""" + energy = None + energy_per_atom = None + force = None + virial = None + virial_per_atom = None + + if bool(test_data.get("find_energy", 0.0)): + energy = compute_error_stat( + prediction["energy"].reshape(-1, 1), + test_data["energy"].reshape(-1, 1), + ) + energy_per_atom = compute_error_stat( + prediction["energy"].reshape(-1, 1), + test_data["energy"].reshape(-1, 1), + scale=1.0 / natoms, + ) + + if bool(test_data.get("find_force", 0.0)): + force = compute_error_stat( + prediction["force"].reshape(-1), + test_data["force"].reshape(-1), + ) + + if has_pbc and bool(test_data.get("find_virial", 0.0)): + virial = compute_error_stat( + prediction["virial"].reshape(-1, 9), + test_data["virial"].reshape(-1, 9), + ) + virial_per_atom = compute_error_stat( + prediction["virial"].reshape(-1, 9), + test_data["virial"].reshape(-1, 9), + scale=1.0 / natoms, + ) + + return EnergyTypeEvalMetrics( + energy=energy, + energy_per_atom=energy_per_atom, + 
force=force, + virial=virial, + virial_per_atom=virial_per_atom, + ) + + +def compute_spin_force_metrics( + force_real_prediction: np.ndarray, + force_real_reference: np.ndarray, + *, + force_magnetic_prediction: np.ndarray | None = None, + force_magnetic_reference: np.ndarray | None = None, +) -> SpinForceEvalMetrics: + """Compute spin-aware force metrics from aligned real and magnetic forces.""" + force_real = compute_error_stat(force_real_prediction, force_real_reference) + force_magnetic = None + if force_magnetic_prediction is not None or force_magnetic_reference is not None: + if force_magnetic_prediction is None or force_magnetic_reference is None: + raise ValueError( + "Spin magnetic force metrics require both prediction and reference." + ) + force_magnetic = compute_error_stat( + force_magnetic_prediction, + force_magnetic_reference, + ) + return SpinForceEvalMetrics( + force_real=force_real, + force_magnetic=force_magnetic, + ) diff --git a/source/tests/pt/test_dp_test.py b/source/tests/pt/test_dp_test.py index 1c11541e50..b5c6365a57 100644 --- a/source/tests/pt/test_dp_test.py +++ b/source/tests/pt/test_dp_test.py @@ -84,28 +84,32 @@ def _run_dp_test( pred_f, to_numpy_array(result["force"]).reshape(-1, 3), ) - pred_v = np.loadtxt(self.detail_file + ".v.out", ndmin=2)[:, 9:18] - np.testing.assert_almost_equal( - pred_v, - to_numpy_array(result["virial"]), - ) - pred_v_peratom = np.loadtxt(self.detail_file + ".v_peratom.out", ndmin=2)[ - :, 9:18 - ] - np.testing.assert_almost_equal(pred_v_peratom, pred_v / natom) + if os.path.exists(self.detail_file + ".v.out"): + pred_v = np.loadtxt(self.detail_file + ".v.out", ndmin=2)[:, 9:18] + np.testing.assert_almost_equal( + pred_v, + to_numpy_array(result["virial"]), + ) + pred_v_peratom = np.loadtxt( + self.detail_file + ".v_peratom.out", ndmin=2 + )[:, 9:18] + np.testing.assert_almost_equal(pred_v_peratom, pred_v / natom) + else: + self.assertFalse(os.path.exists(self.detail_file + ".v_peratom.out")) else: 
class TestFullValidation(unittest.TestCase):
    """Trainer-level tests for the ``validating`` (full validation) section."""

    def setUp(self) -> None:
        """Enter a scratch directory and build a four-step training config."""
        self._cwd = os.getcwd()
        self._tmpdir = tempfile.TemporaryDirectory()
        # unittest does not call tearDown when setUp raises, but it always
        # runs registered cleanups. Register them before any fallible work so
        # a failure below cannot leak the cwd change or the temp directory.
        # Cleanups run LIFO: the cwd is restored before the directory is
        # removed.
        self.addCleanup(self._tmpdir.cleanup)
        os.chdir(self._tmpdir.name)
        self.addCleanup(os.chdir, self._cwd)

        input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
            self.config = json.load(f)
        self.config = convert_optimizer_v31_to_v32(self.config, warning=False)
        data_file = [str(Path(__file__).parent / "water/data/data_0")]
        self.config["training"]["training_data"]["systems"] = data_file
        self.config["training"]["validation_data"]["systems"] = data_file
        self.config["model"] = deepcopy(model_se_e2_a)
        self.config["training"]["numb_steps"] = 4
        self.config["training"]["save_freq"] = 100
        self.config["training"]["disp_training"] = False
        self.config["validating"] = {
            "full_validation": True,
            "validation_freq": 1,
            "save_best": True,
            "max_best_ckpt": 2,
            "validation_metric": "E:MAE",
            "full_val_file": "val.log",
            "full_val_start": 0.0,
        }

    @staticmethod
    def _read_val_records(path: str = "val.log") -> list[str]:
        """Return the lines of the validation log that do not start with ``#``."""
        with open(path) as fp:
            return [line for line in fp if not line.startswith("#")]

    @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
    def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None:
        """Only the two best checkpoints remain after four mocked evaluations."""
        mocked_eval.side_effect = [
            {"mae_e_per_atom": 1.0},
            {"mae_e_per_atom": 2.0},
            {"mae_e_per_atom": 0.5},
            {"mae_e_per_atom": 1.5},
        ]
        # A pre-existing best checkpoint that the run is expected to remove.
        Path("best.ckpt-999.t-1.pt").touch()
        trainer = get_trainer(deepcopy(self.config))
        trainer.run()

        self.assertFalse(Path("best.ckpt-999.t-1.pt").exists())
        self.assertFalse(Path("best.ckpt-1.t-1.pt").exists())
        self.assertFalse(Path("best.ckpt-2.t-1.pt").exists())
        self.assertTrue(Path("best.ckpt-3.t-1.pt").exists())
        self.assertTrue(Path("best.ckpt-1.t-2.pt").exists())
        train_infos = trainer._get_inner_module().train_infos
        self.assertEqual(
            train_infos["full_validation_topk_records"],
            [
                {"metric": 0.5, "step": 3},
                {"metric": 1.0, "step": 1},
            ],
        )
        val_lines = self._read_val_records()
        self.assertEqual(len(val_lines), 4)
        self.assertEqual(val_lines[0].split()[1], "1000.0")
        self.assertEqual(val_lines[1].split()[1], "2000.0")

    @patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
    def test_full_validation_runs_when_start_step_is_final_step(
        self, mocked_eval
    ) -> None:
        """Validation still runs once when it starts at the last training step."""
        mocked_eval.return_value = {"mae_e_per_atom": 1.0}
        config = deepcopy(self.config)
        config["validating"]["full_val_start"] = config["training"]["numb_steps"]

        trainer = get_trainer(config)
        trainer.run()

        mocked_eval.assert_called_once()
        self.assertEqual(len(self._read_val_records()), 1)

    def test_full_validation_uses_normalized_defaults_in_get_trainer(self) -> None:
        """A minimal ``validating`` section gets the normalized default values."""
        config = deepcopy(self.config)
        config["validating"] = {"full_validation": True}
        normalized = normalize(update_deepmd_input(deepcopy(config), warning=False))

        trainer = get_trainer(config)

        self.assertIsNotNone(trainer.full_validator)
        # Plain assert kept so static type checkers narrow the Optional.
        assert trainer.full_validator is not None
        self.assertEqual(
            trainer.full_validator.validation_freq,
            normalized["validating"]["validation_freq"],
        )
        self.assertEqual(
            trainer.full_validator.start_step,
            int(
                normalized["training"]["numb_steps"]
                * normalized["validating"]["full_val_start"]
            ),
        )

    def test_full_validation_rejects_spin_loss(self) -> None:
        """Building a trainer with an ``ener_spin`` loss raises ``ValueError``."""
        config = deepcopy(self.config)
        config["loss"]["type"] = "ener_spin"
        with self.assertRaisesRegex(ValueError, "spin-energy"):
            get_trainer(config)

    def test_full_validation_rejects_multitask(self) -> None:
        """Normalizing a multi-task config with full validation raises."""
        multitask_json = str(Path(__file__).parent / "water/multitask.json")
        with open(multitask_json) as f:
            config = json.load(f)
        data_file = [str(Path(__file__).parent / "water/data/data_0")]
        for model_key, task_config in config["training"]["data_dict"].items():
            task_config["training_data"]["systems"] = data_file
            task_config["validation_data"]["systems"] = data_file
            task_config["stat_file"] = f"stat_files_{model_key}"
        config["training"]["numb_steps"] = 1
        config["training"]["save_freq"] = 1
        config["validating"] = {
            "full_validation": True,
            "validation_freq": 1,
            "save_best": True,
            "validation_metric": "E:MAE",
            "full_val_file": "val.log",
            "full_val_start": 0.0,
        }
        config["model"], _ = preprocess_shared_params(config["model"])
        config = update_deepmd_input(config, warning=False)
        with self.assertRaisesRegex(ValueError, "multi-task"):
            normalize(config, multi_task=True)
class TestValidationHelpers(unittest.TestCase):
    """Unit tests for the start-step resolver and best-checkpoint bookkeeping."""

    @staticmethod
    def _build_validator(
        train_infos: dict, *, restart_training: bool
    ) -> "FullValidator":
        """Create a validator with dummy data/model that keeps two best checkpoints."""
        return FullValidator(
            validating_params={
                "full_validation": True,
                "validation_freq": 1,
                "save_best": True,
                "max_best_ckpt": 2,
                "validation_metric": "E:MAE",
                "full_val_file": "val.log",
                "full_val_start": 0.0,
            },
            validation_data=_DummyValidationData(),
            model=_DummyModel(),
            train_infos=train_infos,
            num_steps=10,
            rank=0,
            zero_stage=0,
            restart_training=restart_training,
        )

    def test_resolve_full_validation_start_step(self) -> None:
        """Check the resolved start step for several ``full_val_start`` values."""
        total_steps = 2000000
        for raw_start, expected_step in ((0, 0), (0.1, 200000), (5000, 5000)):
            self.assertEqual(
                resolve_full_validation_start_step(raw_start, total_steps),
                expected_step,
            )
        self.assertIsNone(resolve_full_validation_start_step(1, total_steps))

    def test_full_validator_rotates_best_checkpoint(self) -> None:
        """Feed three metrics in sequence and check the two retained checkpoints."""
        train_infos = {}
        # (display_step, selected_metric_value) in the order they are reported.
        observations = ((1, 2.0), (2, 1.0), (3, 1.5))
        with tempfile.TemporaryDirectory() as tmpdir:
            previous_cwd = os.getcwd()
            os.chdir(tmpdir)
            try:
                validator = self._build_validator(
                    train_infos, restart_training=False
                )
                for step, metric_value in observations:
                    new_best_path = validator._update_best_state(
                        display_step=step,
                        selected_metric_value=metric_value,
                    )
                    Path(new_best_path).touch()
                    validator._reconcile_best_checkpoints()
            finally:
                os.chdir(previous_cwd)

            self.assertEqual(new_best_path, "best.ckpt-3.t-2.pt")
            remaining = sorted(
                path.name for path in Path(tmpdir).glob("best.ckpt-*.pt")
            )
            self.assertEqual(
                remaining,
                ["best.ckpt-2.t-1.pt", "best.ckpt-3.t-2.pt"],
            )
        self.assertEqual(
            train_infos[TOPK_RECORDS_INFO_KEY],
            [
                {"metric": 1.0, "step": 2},
                {"metric": 1.5, "step": 3},
            ],
        )
        self.assertEqual(train_infos[BEST_METRIC_NAME_INFO_KEY], "e:mae")

    def test_full_validator_restores_top_k_checkpoints(self) -> None:
        """Constructing with ``restart_training=True`` renames/prunes checkpoints."""
        train_infos = {
            BEST_METRIC_NAME_INFO_KEY: "e:mae",
            TOPK_RECORDS_INFO_KEY: [
                {"metric": 1.0, "step": 20},
                {"metric": 2.0, "step": 10},
            ],
        }
        preexisting = (
            "best.ckpt-20.t-9.pt",
            "best.ckpt-10.t-8.pt",
            "best.ckpt-999.t-1.pt",
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            previous_cwd = os.getcwd()
            os.chdir(tmpdir)
            try:
                for checkpoint_name in preexisting:
                    Path(checkpoint_name).touch()
                self._build_validator(train_infos, restart_training=True)
            finally:
                os.chdir(previous_cwd)

            remaining = sorted(
                path.name for path in Path(tmpdir).glob("best.ckpt-*.pt")
            )
            self.assertEqual(
                remaining,
                ["best.ckpt-10.t-2.pt", "best.ckpt-20.t-1.pt"],
            )


class TestValidationArgcheck(unittest.TestCase):
    """Checks that ``normalize`` rejects invalid ``validating`` configurations."""

    def test_normalize_rejects_missing_validation_data(self) -> None:
        """A config without ``training.validation_data`` is rejected."""
        config = _make_single_task_config()
        config["training"].pop("validation_data")
        with self.assertRaisesRegex(ValueError, "training.validation_data"):
            normalize(config)

    def test_normalize_rejects_inactive_prefactor_metric(self) -> None:
        """``F:RMSE`` is rejected when either force prefactor is zero."""
        prefactor_cases = ((0.0, 0.0), (1.0, 0.0), (0.0, 1.0))
        for start_pref_f, limit_pref_f in prefactor_cases:
            with self.subTest(
                start_pref_f=start_pref_f,
                limit_pref_f=limit_pref_f,
            ):
                config = _make_single_task_config()
                config["validating"]["validation_metric"] = "F:RMSE"
                loss_params = config["loss"]
                loss_params["start_pref_f"] = start_pref_f
                loss_params["limit_pref_f"] = limit_pref_f
                with self.assertRaisesRegex(ValueError, "start_pref_f"):
                    normalize(config)

    def test_normalize_rejects_invalid_metric(self) -> None:
        """An unknown metric name is rejected."""
        config = _make_single_task_config()
        config["validating"]["validation_metric"] = "X:MAE"
        with self.assertRaisesRegex(ArgumentValueError, "validation_metric"):
            normalize(config)

    def test_normalize_rejects_invalid_metric_with_num_epoch_schedule(self) -> None:
        """The prefactor check also applies when ``numb_epoch`` replaces ``numb_steps``."""
        config = _make_single_task_config()
        training_params = config["training"]
        training_params.pop("numb_steps")
        training_params["numb_epoch"] = 1.0
        validating_params = config["validating"]
        validating_params["validation_metric"] = "F:RMSE"
        validating_params["full_val_start"] = 2
        config["loss"]["limit_pref_f"] = 0.0
        with self.assertRaisesRegex(ValueError, "start_pref_f"):
            normalize(config)

    def test_normalize_rejects_nonpositive_max_best_ckpt(self) -> None:
        """``max_best_ckpt`` of zero is rejected."""
        config = _make_single_task_config()
        config["validating"]["max_best_ckpt"] = 0
        with self.assertRaisesRegex(ArgumentValueError, "max_best_ckpt"):
            normalize(config)