diff --git a/src/scenicplus/BASCA.py b/src/scenicplus/BASCA.py index 2c7e099..b31f0bc 100644 --- a/src/scenicplus/BASCA.py +++ b/src/scenicplus/BASCA.py @@ -81,6 +81,20 @@ def calc_error(vect, P, i, j): return np.sum(((vect - z) ** 2)[0: N]) +@numba.jit(nopython=True) +def calc_errors(calc_errors_matrix, calc_errors_is_cached, vect, P, i, j): + # Check if we have the cost_ab(vect, a, b) value cached. + if calc_errors_is_cached[i, j] == np.bool_(True): + return calc_errors_matrix[i, j] + + # Else calculate cost_ab(vect, a, b) and cache it for next time. + current_calc_error = calc_error(vect, P, i, j) + calc_errors_matrix[i, j] = current_calc_error + calc_errors_is_cached[i, j] = np.bool_(True) + + return current_calc_error + + @numba.jit(nopython=True) def moving_block_bootstrap(v): N = v.shape[0] @@ -177,13 +191,16 @@ def calc_scores(vect, P): # stores the index of the discontinuity with the maximum score for each step function ind_Q_max = np.zeros(N, dtype=np.int64) + calc_errors_matrix = np.zeros((N, N - 1), np.float64) + calc_errors_is_cached = np.zeros((N, N - 1), np.bool_) + for j in range(0, N): q_max = -1 ind_q_max = -1 for i in range(0, j + 1): # calculate jump height h = calc_jump_height(vect, P, i, j) - e = calc_error(vect, P, i, j) + e = calc_errors(calc_errors_matrix, calc_errors_is_cached, vect, P, i, j) q = h / e if q > q_max: q_max = q