Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 116 additions & 33 deletions source/tests/consistent/descriptor/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
INSTALLED_PT,
INSTALLED_PT_EXPT,
CommonTest,
parameterized,
parameterized_cases,
)
from .common import (
DescriptorAPITest,
Expand Down Expand Up @@ -62,24 +62,120 @@
descrpt_dpa3_args,
)

DPA3_CASE_FIELDS = (
"update_residual_init",
"exclude_types",
"update_angle",
"a_compress_rate",
"a_compress_e_rate",
"a_compress_use_split",
"optim_update",
"edge_init_use_dist",
"use_exp_switch",
"use_dynamic_sel",
"use_loc_mapping",
"fix_stat_std",
"n_multi_edge_message",
"precision",
"add_chg_spin_ebd",
)


DPA3_BASELINE_CASE = {
"update_residual_init": "const",
"exclude_types": [],
"update_angle": True,
"a_compress_rate": 0,
"a_compress_e_rate": 1,
"a_compress_use_split": True,
"optim_update": True,
"edge_init_use_dist": True,
"use_exp_switch": True,
"use_dynamic_sel": True,
"use_loc_mapping": True,
"fix_stat_std": 0.3,
"n_multi_edge_message": 1,
"precision": "float64",
"add_chg_spin_ebd": False,
}


def dpa3_case(**overrides: Any) -> tuple:
case = DPA3_BASELINE_CASE | overrides
return tuple(case[field] for field in DPA3_CASE_FIELDS)


DPA3_CURATED_CASES = (
# Baseline coverage.
dpa3_case(),
# Descriptor-level edge cases.
dpa3_case(exclude_types=[[0, 1]]),
dpa3_case(use_loc_mapping=False),
dpa3_case(add_chg_spin_ebd=True),
# Repflow compression branches.
dpa3_case(a_compress_rate=1),
dpa3_case(a_compress_e_rate=2),
# Repflow update toggles.
dpa3_case(optim_update=False),
dpa3_case(edge_init_use_dist=False),
dpa3_case(use_exp_switch=False),
dpa3_case(use_dynamic_sel=False),
# One mixed high-risk path to keep interactions covered.
dpa3_case(
exclude_types=[[0, 1]],
a_compress_rate=1,
a_compress_e_rate=2,
optim_update=False,
edge_init_use_dist=False,
use_exp_switch=False,
use_dynamic_sel=False,
use_loc_mapping=False,
add_chg_spin_ebd=True,
),
)


@parameterized(
("const",), # update_residual_init
([], [[0, 1]]), # exclude_types
(True,), # update_angle
(0, 1), # a_compress_rate
(1, 2), # a_compress_e_rate
(True,), # a_compress_use_split
(True, False), # optim_update
(True, False), # edge_init_use_dist
(True, False), # use_exp_switch
(True, False), # use_dynamic_sel
(True, False), # use_loc_mapping
(0.3,), # fix_stat_std
(1,), # n_multi_edge_message
("float64",), # precision
(False, True), # add_chg_spin_ebd
DPA3_DESCRIPTOR_API_CASE_FIELDS = DPA3_CASE_FIELDS


def dpa3_descriptor_api_case(**overrides: Any) -> tuple:
case = DPA3_BASELINE_CASE | overrides
return tuple(case[field] for field in DPA3_DESCRIPTOR_API_CASE_FIELDS)


DPA3_DESCRIPTOR_API_CURATED_CASES = (
# Baseline coverage.
dpa3_descriptor_api_case(),
# Descriptor serialization / config toggles.
dpa3_descriptor_api_case(exclude_types=[[0, 1]]),
dpa3_descriptor_api_case(use_loc_mapping=False),
dpa3_descriptor_api_case(fix_stat_std=0.0),
dpa3_descriptor_api_case(add_chg_spin_ebd=True),
# Repflow compression branches.
dpa3_descriptor_api_case(a_compress_rate=1),
dpa3_descriptor_api_case(a_compress_e_rate=2),
# Repflow update toggles.
dpa3_descriptor_api_case(optim_update=False),
dpa3_descriptor_api_case(edge_init_use_dist=False),
dpa3_descriptor_api_case(use_exp_switch=False),
dpa3_descriptor_api_case(use_dynamic_sel=False),
# One mixed high-risk path to keep interactions covered.
dpa3_descriptor_api_case(
exclude_types=[[0, 1]],
a_compress_rate=1,
a_compress_e_rate=2,
optim_update=False,
edge_init_use_dist=False,
use_exp_switch=False,
use_dynamic_sel=False,
use_loc_mapping=False,
fix_stat_std=0.0,
add_chg_spin_ebd=True,
),
)


@parameterized_cases(*DPA3_CURATED_CASES)
class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
Expand Down Expand Up @@ -430,22 +526,7 @@ def atol(self) -> float:
raise ValueError(f"Unknown precision: {precision}")


@parameterized(
("const",), # update_residual_init
([], [[0, 1]]), # exclude_types
(True,), # update_angle
(0, 1), # a_compress_rate
(1, 2), # a_compress_e_rate
(True,), # a_compress_use_split
(True, False), # optim_update
(True, False), # edge_init_use_dist
(True, False), # use_exp_switch
(True, False), # use_dynamic_sel
(True, False), # use_loc_mapping
(0.3, 0.0), # fix_stat_std
(1,), # n_multi_edge_message
("float64",), # precision
)
@parameterized_cases(*DPA3_DESCRIPTOR_API_CURATED_CASES)
class TestDPA3DescriptorAPI(DescriptorAPITest, unittest.TestCase):
"""Test consistency of BaseDescriptor API methods across backends."""

Expand All @@ -471,6 +552,7 @@ def data(self) -> dict:
fix_stat_std,
n_multi_edge_message,
precision,
add_chg_spin_ebd,
) = self.param
return {
"ntypes": self.ntypes,
Expand Down Expand Up @@ -511,4 +593,5 @@ def data(self) -> dict:
"env_protection": 0.0,
"use_loc_mapping": use_loc_mapping,
"trainable": False,
"add_chg_spin_ebd": add_chg_spin_ebd,
}
Loading