From 64e4cfafb4ec6af46e66e382aec2b6408814b476 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 6 Nov 2024 16:10:31 +0100 Subject: [PATCH] schnet with einsum --- modelforge/potential/schnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 8383f09d..7144d406 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -294,11 +294,11 @@ def forward( # Generate interaction filters based on radial basis functions W_ij = self.filter_network(f_ij.squeeze(1)) - W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters] + W_ij = torch.einsum("nf,n->nf", W_ij, f_ij_cutoff.squeeze(-1)) # Perform continuous-filter convolution x_j = atomic_embedding[idx_j] - x_ij = x_j * W_ij # Element-wise multiplication + x_ij = torch.einsum("nk,nk->nk", W_ij, x_j) out = torch.zeros_like(atomic_embedding).scatter_add_( 0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij