diff --git a/deeptime/decomposition/deep/_vampnet.py b/deeptime/decomposition/deep/_vampnet.py index fbc3ab3e..afee96b5 100644 --- a/deeptime/decomposition/deep/_vampnet.py +++ b/deeptime/decomposition/deep/_vampnet.py @@ -213,9 +213,11 @@ def vamp_score(data: "torch.Tensor", data_lagged: "torch.Tensor", method='VAMP2' out = torch.pow(torch.norm(koopman, p='fro'), 2) elif method == 'VAMPE': c00, c0t, ctt = covariances(data, data_lagged, remove_mean=True) - c00_sqrt_inv = sym_inverse(c00, epsilon=epsilon, return_sqrt=True, mode=mode) - ctt_sqrt_inv = sym_inverse(ctt, epsilon=epsilon, return_sqrt=True, mode=mode) - koopman = multi_dot([c00_sqrt_inv, c0t, ctt_sqrt_inv]).t() + # in original paper of VAMPE, inv can be detached from gradient + c00_sqrt_inv = sym_inverse(c00, epsilon=epsilon, return_sqrt=True, mode=mode).detach() + ctt_sqrt_inv = sym_inverse(ctt, epsilon=epsilon, return_sqrt=True, mode=mode).detach() + # detach koopman, so that VAMPE is only depedent on the trace + koopman = multi_dot([c00_sqrt_inv, c0t, ctt_sqrt_inv]).t().detach() u, s, v = torch.svd(koopman) mask = s > epsilon