feat: support aparam derivative in ener loss #5285
The `EnergyStdLoss` constructor gains three new arguments:

```diff
@@ -57,6 +57,9 @@ def __init__(
         inference: bool = False,
         use_huber: bool = False,
         huber_delta: float = 0.01,
+        start_pref_ap: float = 0.0,
+        limit_pref_ap: float = 0.0,
+        numb_aparam: int = 0,
         **kwargs: Any,
     ) -> None:
         r"""Construct a layer to compute loss on energy, force and virial.
```
The docstring documents the new parameters:

```diff
@@ -109,6 +112,12 @@ def __init__(
             Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
         huber_delta : float
             The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
+        start_pref_ap : float
+            The prefactor of aparam gradient loss at the start of the training.
+        limit_pref_ap : float
+            The prefactor of aparam gradient loss at the end of the training.
+        numb_aparam : int
+            The dimension of atomic parameters. Required when aparam gradient loss is enabled.
         **kwargs
            Other keyword arguments.
        """
```
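The two prefactors follow the same learning-rate-driven schedule as the existing energy/force terms (`coef = learning_rate / self.starter_learning_rate` in `forward`). A minimal sketch of that interpolation; the helper name `scheduled_pref` is illustrative, not part of the PR:

```python
# Sketch of the prefactor schedule used by the ener loss terms.
# The prefactor interpolates linearly in the learning-rate ratio, so it
# starts at start_pref (lr == start_lr) and tends to limit_pref as lr -> 0.
def scheduled_pref(start_pref: float, limit_pref: float, lr: float, start_lr: float) -> float:
    coef = lr / start_lr
    return limit_pref + (start_pref - limit_pref) * coef

# e.g. start_pref_ap=1.0, limit_pref_ap=0.1:
# lr == start_lr        -> 1.0
# lr == 0.1 * start_lr  -> 0.19
# lr -> 0               -> 0.1
```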
The settings are validated and stored in `__init__`:

```diff
@@ -151,6 +160,15 @@ def __init__(
                 "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
             )

+        self.has_ap = start_pref_ap != 0.0 or limit_pref_ap != 0.0
+        if self.has_ap and numb_aparam == 0:
+            raise RuntimeError(
+                "numb_aparam must be > 0 when aparam gradient loss is enabled"
+            )
+        self.start_pref_ap = start_pref_ap
+        self.limit_pref_ap = limit_pref_ap
+        self.numb_aparam = numb_aparam
+
     def forward(
         self,
         input_dict: dict[str, torch.Tensor],
```
At the top of `forward`, `aparam` is swapped for a detached, grad-tracking copy before the model is called:

```diff
@@ -182,6 +200,16 @@ def forward(
         more_loss: dict[str, torch.Tensor]
             Other losses for display.
         """
+        ap_for_grad: torch.Tensor | None = None
+        if (
+            self.has_ap
+            and input_dict.get("aparam") is not None
+            and torch.is_grad_enabled()
+        ):
+            ap_for_grad = input_dict["aparam"].detach()
+            ap_for_grad.requires_grad_(True)
+            input_dict = {**input_dict, "aparam": ap_for_grad}
+
         model_pred = model(**input_dict)
         coef = learning_rate / self.starter_learning_rate
         pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
```
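Detaching and re-marking the tensor makes it a fresh autograd leaf, so `torch.autograd.grad` can later differentiate the predicted energy with respect to it even when the loader-provided `aparam` carries no gradient history. A self-contained sketch of the pattern, with a hypothetical linear model standing in for the energy model:

```python
import torch

# Hypothetical stand-in for the energy model: per-atom energy from aparam.
model = torch.nn.Linear(2, 1)

aparam = torch.rand(4, 16, 2)           # [nf, nloc, numb_aparam], from a data loader
ap_for_grad = aparam.detach()           # cut any upstream history
ap_for_grad.requires_grad_(True)        # make it a leaf that autograd tracks

energy = model(ap_for_grad).sum(dim=1)  # [nf, 1]
grad_ap = torch.autograd.grad(energy.sum(), ap_for_grad)[0]
print(grad_ap.shape)                    # torch.Size([4, 16, 2])
```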
The new loss term differentiates the predicted energy with respect to `aparam` and matches it against the `grad_aparam` label (inline comments translated from Chinese):

```diff
@@ -402,6 +430,37 @@ def forward(
                 rmse_ae.detach(), find_atom_ener
             )

+        if (
+            self.has_ap
+            and ap_for_grad is not None
+            and "energy" in model_pred
+            and "grad_aparam" in label
+        ):
+            find_grad_ap = label.get("find_grad_aparam", 0.0)
+            pref_ap = (
+                self.limit_pref_ap + (self.start_pref_ap - self.limit_pref_ap) * coef
+            ) * find_grad_ap
+            energy_pred = model_pred["energy"]  # [nf, 1]
+            # compute d(sum_E)/d(aparam_raw), shape [nf, nloc, numb_aparam]
+            grad_ap_pred = torch.autograd.grad(
+                [energy_pred.sum()],
+                [ap_for_grad],
+                create_graph=True,  # let second-order gradients flow back to the model parameters
+                retain_graph=True,  # keep the graph so the energy/force losses can still backpropagate
+            )[0]
+            assert grad_ap_pred is not None
+            grad_ap_label = label["grad_aparam"].to(grad_ap_pred.dtype)
+            diff_ap = (grad_ap_label - grad_ap_pred).reshape(-1)
+            l2_ap_loss = torch.mean(torch.square(diff_ap))
+            if not self.inference:
+                more_loss["l2_grad_aparam_loss"] = self.display_if_exist(
+                    l2_ap_loss.detach(), find_grad_ap
+                )
+            loss += (pref_ap * l2_ap_loss).to(GLOBAL_PT_FLOAT_PRECISION)
+            more_loss["rmse_grad_aparam"] = self.display_if_exist(
+                l2_ap_loss.sqrt().detach(), find_grad_ap
+            )
+
         if not self.inference:
             more_loss["rmse"] = torch.sqrt(loss.detach())
         return model_pred, loss, more_loss
```
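Because of `create_graph=True`, `grad_ap_pred` is itself a differentiable function of the model parameters, so the squared mismatch against the labeled derivative produces parameter gradients through a second backward pass (the standard derivative-matching, or Sobolev-training, trick). A tiny demonstration of that double backward; the linear model and zero labels are placeholders:

```python
import torch

model = torch.nn.Linear(3, 1)
x = torch.rand(8, 3, requires_grad=True)

energy = model(x).sum()
# create_graph=True keeps dE/dx inside the autograd graph
dE_dx = torch.autograd.grad(energy, x, create_graph=True)[0]

grad_label = torch.zeros_like(dE_dx)       # pretend labeled derivatives
loss = torch.mean((dE_dx - grad_label) ** 2)
loss.backward()                            # second-order pass reaches the weights
print(model.weight.grad is not None)       # True
```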
|
@@ -482,6 +541,16 @@ def label_requirement(self) -> list[DataRequirementItem]: | |||||||
| default=1.0, | ||||||||
| ) | ||||||||
| ) | ||||||||
| if self.has_ap: | ||||||||
| label_requirement.append( | ||||||||
| DataRequirementItem( | ||||||||
| "grad_aparam", | ||||||||
| ndof=self.numb_aparam, | ||||||||
| atomic=True, | ||||||||
| must=False, | ||||||||
| high_prec=False, | ||||||||
| ) | ||||||||
| ) | ||||||||
| return label_requirement | ||||||||
|
|
||||||||
| def serialize(self) -> dict: | ||||||||
|
|
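With `atomic=True` and `ndof=self.numb_aparam`, the data system expects one label per atom per aparam component. In the usual DeePMD-kit data convention, such labels live in a per-set `.npy` file named after the requirement key and flattened over atoms; the exact file name here is an assumption based on how other atomic labels are handled. A sketch of producing such a file:

```python
import os
import numpy as np

nframes, natoms, numb_aparam = 10, 64, 2   # hypothetical sizes

# One dE/d(aparam) label per atom per aparam component.
grad_aparam = np.random.randn(nframes, natoms, numb_aparam)

# Atomic labels are stored flattened to [nframes, natoms * ndof];
# the file name mirrors the DataRequirementItem key (an assumption here).
os.makedirs("set.000", exist_ok=True)
np.save("set.000/grad_aparam.npy", grad_aparam.reshape(nframes, -1))
```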
`serialize` records the new prefactors:

```diff
@@ -510,6 +579,8 @@ def serialize(self) -> dict:
             "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
             "start_pref_gf": self.start_pref_gf,
             "limit_pref_gf": self.limit_pref_gf,
+            "start_pref_ap": self.start_pref_ap,
+            "limit_pref_ap": self.limit_pref_ap,
             "numb_generalized_coord": self.numb_generalized_coord,
```

A review suggestion additionally adds `"numb_aparam": self.numb_aparam,` after the `"numb_generalized_coord"` entry, so the dimension survives a serialize/deserialize round trip.
In the trainer's `get_loss` dispatch, `numb_aparam` is filled in from the model when the term is enabled:

```diff
@@ -1690,6 +1690,11 @@ def get_loss(
         return EnergyHessianStdLoss(**loss_params)
     elif loss_type == "ener":
         loss_params["starter_learning_rate"] = start_lr
+        if (
+            loss_params.get("start_pref_ap", 0.0) != 0.0
+            or loss_params.get("limit_pref_ap", 0.0) != 0.0
+        ):
+            loss_params["numb_aparam"] = _model.get_dim_aparam()
         return EnergyStdLoss(**loss_params)
     elif loss_type == "dos":
         loss_params["starter_learning_rate"] = start_lr
```
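With this dispatch, enabling the new term should only require the two prefactors in the loss section of the training input, since `numb_aparam` is derived from the model via `_model.get_dim_aparam()`. A hypothetical configuration sketch, shown as a Python dict; all keys other than `start_pref_ap`/`limit_pref_ap` are existing ener-loss options:

```python
# Hypothetical "loss" section of a DeePMD-kit training input.
# Nonzero aparam prefactors switch on the gradient-matching term.
loss_config = {
    "type": "ener",
    "start_pref_e": 0.02,
    "limit_pref_e": 1.0,
    "start_pref_f": 1000,
    "limit_pref_f": 1.0,
    "start_pref_ap": 1.0,   # new in this PR
    "limit_pref_ap": 0.1,   # new in this PR
}
```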