diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 8383f09d..7144d406 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -294,11 +294,11 @@ def forward( # Generate interaction filters based on radial basis functions W_ij = self.filter_network(f_ij.squeeze(1)) - W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters] + W_ij = torch.einsum("nf,n->nf", W_ij, f_ij_cutoff.squeeze(-1)) # Perform continuous-filter convolution x_j = atomic_embedding[idx_j] - x_ij = x_j * W_ij # Element-wise multiplication + x_ij = torch.einsum("nk,nk->nk", W_ij, x_j) out = torch.zeros_like(atomic_embedding).scatter_add_( 0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij