diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index bca0759f5c..b067ca94dc 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from copy import ( + deepcopy, +) from typing import ( Any, ) @@ -21,7 +24,7 @@ INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, - parameterized, + parameterized_cases, ) from .common import ( DescriptorAPITest, @@ -62,24 +65,127 @@ descrpt_dpa3_args, ) +DPA3_CASE_FIELDS = ( + "update_residual_init", + "exclude_types", + "update_angle", + "a_compress_rate", + "a_compress_e_rate", + "a_compress_use_split", + "optim_update", + "edge_init_use_dist", + "use_exp_switch", + "use_dynamic_sel", + "use_loc_mapping", + "fix_stat_std", + "n_multi_edge_message", + "precision", + "add_chg_spin_ebd", +) + -@parameterized( - ("const",), # update_residual_init - ([], [[0, 1]]), # exclude_types - (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate - (True,), # a_compress_use_split - (True, False), # optim_update - (True, False), # edge_init_use_dist - (True, False), # use_exp_switch - (True, False), # use_dynamic_sel - (True, False), # use_loc_mapping - (0.3,), # fix_stat_std - (1,), # n_multi_edge_message - ("float64",), # precision - (False, True), # add_chg_spin_ebd +DPA3_BASELINE_CASE = { + "update_residual_init": "const", + "exclude_types": [], + "update_angle": True, + "a_compress_rate": 0, + "a_compress_e_rate": 1, + "a_compress_use_split": True, + "optim_update": True, + "edge_init_use_dist": True, + "use_exp_switch": True, + "use_dynamic_sel": True, + "use_loc_mapping": True, + "fix_stat_std": 0.3, + "n_multi_edge_message": 1, + "precision": "float64", + "add_chg_spin_ebd": False, +} + + +def _build_dpa3_case(fields: tuple[str, ...], **overrides: Any) -> tuple: + unknown = set(overrides) - set(DPA3_BASELINE_CASE) + if unknown: + raise KeyError(f"Unknown DPA3 case override(s): {sorted(unknown)}") + case = deepcopy(DPA3_BASELINE_CASE) + case.update(overrides) + return tuple(case[field] for field in fields) + + +def dpa3_case(**overrides: Any) -> tuple: + return _build_dpa3_case(DPA3_CASE_FIELDS, **overrides) + + +DPA3_CURATED_CASES = ( + # Baseline coverage. + dpa3_case(), + # Descriptor-level edge cases. + dpa3_case(exclude_types=[[0, 1]]), + dpa3_case(use_loc_mapping=False), + dpa3_case(add_chg_spin_ebd=True), + # Repflow compression branches. + dpa3_case(a_compress_rate=1), + dpa3_case(a_compress_e_rate=2), + # Repflow update toggles. + dpa3_case(optim_update=False), + dpa3_case(edge_init_use_dist=False), + dpa3_case(use_exp_switch=False), + dpa3_case(use_dynamic_sel=False), + # One mixed high-risk path to keep interactions covered. + dpa3_case( + exclude_types=[[0, 1]], + a_compress_rate=1, + a_compress_e_rate=2, + optim_update=False, + edge_init_use_dist=False, + use_exp_switch=False, + use_dynamic_sel=False, + use_loc_mapping=False, + add_chg_spin_ebd=True, + ), ) + + +DPA3_DESCRIPTOR_API_CASE_FIELDS = DPA3_CASE_FIELDS + + +def dpa3_descriptor_api_case(**overrides: Any) -> tuple: + return _build_dpa3_case(DPA3_DESCRIPTOR_API_CASE_FIELDS, **overrides) + + +DPA3_DESCRIPTOR_API_CURATED_CASES = ( + # Baseline coverage. + dpa3_descriptor_api_case(), + # Descriptor serialization / config toggles. + dpa3_descriptor_api_case(exclude_types=[[0, 1]]), + dpa3_descriptor_api_case(use_loc_mapping=False), + dpa3_descriptor_api_case(fix_stat_std=0.0), + dpa3_descriptor_api_case(add_chg_spin_ebd=True), + # Repflow compression branches. + dpa3_descriptor_api_case(a_compress_rate=1), + dpa3_descriptor_api_case(a_compress_e_rate=2), + # Repflow update toggles. + dpa3_descriptor_api_case(optim_update=False), + dpa3_descriptor_api_case(edge_init_use_dist=False), + dpa3_descriptor_api_case(use_exp_switch=False), + dpa3_descriptor_api_case(use_dynamic_sel=False), + # One mixed high-risk path to keep interactions covered. + dpa3_descriptor_api_case( + exclude_types=[[0, 1]], + a_compress_rate=1, + a_compress_e_rate=2, + optim_update=False, + edge_init_use_dist=False, + use_exp_switch=False, + use_dynamic_sel=False, + use_loc_mapping=False, + fix_stat_std=0.0, + add_chg_spin_ebd=True, + ), +) + + +@parameterized_cases(*DPA3_CURATED_CASES) class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): @property def data(self) -> dict: @@ -145,41 +251,41 @@ def data(self) -> dict: @property def skip_pt(self) -> bool: ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, - precision, - add_chg_spin_ebd, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, + _precision, + _add_chg_spin_ebd, ) = self.param return CommonTest.skip_pt @property def skip_pd(self) -> bool: ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, - precision, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, + _precision, add_chg_spin_ebd, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -187,42 +293,42 @@ def skip_pd(self) -> bool: @property def skip_dp(self) -> bool: ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, - precision, - add_chg_spin_ebd, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, + _precision, + _add_chg_spin_ebd, ) = self.param return CommonTest.skip_dp @property def skip_tf(self) -> bool: ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, - precision, - add_chg_spin_ebd, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, + _precision, + _add_chg_spin_ebd, ) = self.param return True @@ -273,20 +379,20 @@ def setUp(self) -> None: ) self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, - precision, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, + _precision, add_chg_spin_ebd, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True @@ -379,21 +485,21 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: def rtol(self) -> float: """Relative tolerance for comparing the return value.""" ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, precision, - add_chg_spin_ebd, + _add_chg_spin_ebd, ) = self.param if precision == "float64": return 1e-10 @@ -406,21 +512,21 @@ def rtol(self) -> float: def atol(self) -> float: """Absolute tolerance for comparing the return value.""" ( - update_residual_init, - exclude_types, - update_angle, - a_compress_rate, - a_compress_e_rate, - a_compress_use_split, - optim_update, - edge_init_use_dist, - use_exp_switch, - use_dynamic_sel, - use_loc_mapping, - fix_stat_std, - n_multi_edge_message, + _update_residual_init, + _exclude_types, + _update_angle, + _a_compress_rate, + _a_compress_e_rate, + _a_compress_use_split, + _optim_update, + _edge_init_use_dist, + _use_exp_switch, + _use_dynamic_sel, + _use_loc_mapping, + _fix_stat_std, + _n_multi_edge_message, precision, - add_chg_spin_ebd, + _add_chg_spin_ebd, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 @@ -430,22 +536,7 @@ def atol(self) -> float: raise ValueError(f"Unknown precision: {precision}") -@parameterized( - ("const",), # update_residual_init - ([], [[0, 1]]), # exclude_types - (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate - (True,), # a_compress_use_split - (True, False), # optim_update - (True, False), # edge_init_use_dist - (True, False), # use_exp_switch - (True, False), # use_dynamic_sel - (True, False), # use_loc_mapping - (0.3, 0.0), # fix_stat_std - (1,), # n_multi_edge_message - ("float64",), # precision -) +@parameterized_cases(*DPA3_DESCRIPTOR_API_CURATED_CASES) class TestDPA3DescriptorAPI(DescriptorAPITest, unittest.TestCase): """Test consistency of BaseDescriptor API methods across backends.""" @@ -471,6 +562,7 @@ def data(self) -> dict: fix_stat_std, n_multi_edge_message, precision, + add_chg_spin_ebd, ) = self.param return { "ntypes": self.ntypes, @@ -511,4 +603,5 @@ def data(self) -> dict: "env_protection": 0.0, "use_loc_mapping": use_loc_mapping, "trainable": False, + "add_chg_spin_ebd": add_chg_spin_ebd, }