diff --git a/src/tfmindi/pp/seqlets.py b/src/tfmindi/pp/seqlets.py index c8a42db..d9e625c 100644 --- a/src/tfmindi/pp/seqlets.py +++ b/src/tfmindi/pp/seqlets.py @@ -531,11 +531,15 @@ def recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, ad THIS FUNCTION IS A DIRECT COPY FROM THE TANGERMEME REPOSITORY FROM JACOB SCHREIBER. We do a direct copy here since we only need this function and we want to avoid the heavy torch installation. + NOTE: `_recursive_seqlets` builds separate null distributions for positive + and negative attribution values and picks one by the sign of each span's + attribution sum, so negative seqlets are identified directly; running this + on the absolute value of the attributions is not required. + This algorithm identifies spans of high attribution characters, called seqlets, using a simple approach derived from the Tomtom/FIMO algorithms. First, distributions of attribution sums are created for all potential - seqlet lengths by discretizing the sum, with one set of distributions for - positive attribution values and one for negative attribution values. Then, + seqlet lengths by discretizing the attribution sum into integers. Then, CDFs are calculated for each distribution (or, more specifically, 1-CDFs). Finally, p-values are calculated via lookup to these 1-CDFs for all potential CDFs, yielding a (n_positions, n_lengths) matrix of p-values. @@ -618,78 +622,164 @@ def recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, ad @numba.njit def _recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, additional_flanks=0, n_bins=1000): - """Call seqlets recursively using the Tangermeme algorithm. + """ + Call seqlets recursively using the Tangermeme algorithm with separate null distributions for positive and negative scores. This algorithm has four steps. 
- (1) Convert attribution scores into integer bins and calculate a histogram - (2) Convert these histograms into null distributions across lengths - (3) Use the null distributions to calculate p-values for each possible length + (1) Convert attribution scores into integer bins and calculate separate histograms for positive and negative scores + (2) Convert these histograms into separate null distributions across lengths + (3) Use the appropriate null distribution to calculate p-values for each possible length based on sign (4) Decode this matrix of p-values to find the longest seqlets """ n, l = X.shape - m = n * l + + # Clamp attribution values to prevent extreme outliers from dominating binning + xmax_orig, xmin_orig = X.max(), X.min() + clamp_max = xmin_orig + 0.95 * (xmax_orig - xmin_orig) # 95% of the [min, max] range (not a percentile) + clamp_min = xmin_orig + 0.05 * (xmax_orig - xmin_orig) # 5% of the [min, max] range (not a percentile) + X_clamped = np.clip(X, clamp_min, clamp_max) ### - # Step 1: Calculate a histogram of binned scores + # Step 1: Calculate separate histograms for positive and negative scores ### - xmax, xmin = X.max(), X.min() - bin_width = (xmax - xmin) / (n_bins - 1) + # Find min/max for clamped dataset + xmax, xmin = X_clamped.max(), X_clamped.min() + + # Calculate separate bin ranges for positive and negative scores + # For positive: use [max(0, xmin), xmax] if xmax > 0, else use [xmin, 0] + if xmax > 0: + xmax_pos = xmax + xmin_pos = max(0.0, xmin) # Start from 0 or higher + else: + xmax_pos = 0.0 + xmin_pos = xmin + + # For negative: use absolute values [0, |xmin|] if xmin < 0, else use [0, xmax] + if xmin < 0: + xmax_neg = abs(xmin) + xmin_neg = 0.0 + else: + xmax_neg = xmax + xmin_neg = 0.0 + + # Prevent division by zero + if xmax_pos == xmin_pos: + xmax_pos = xmin_pos + 1e-6 + if xmax_neg == xmin_neg: + xmax_neg = xmin_neg + 1e-6 - f = np.zeros(n_bins, dtype=np.float64) + bin_width_pos = (xmax_pos - xmin_pos) / (n_bins - 1) + bin_width_neg = (xmax_neg - xmin_neg) / (n_bins - 1) + + # Build 
distributions and count in single pass through data + f_pos = np.zeros(n_bins, dtype=np.float64) + f_neg = np.zeros(n_bins, dtype=np.float64) + m_pos = 0 + m_neg = 0 for i in range(n): for j in range(l): - x_bin = math.floor((X[i, j] - xmin) / bin_width) - f[x_bin] += 1 - - f = f / m + val = X_clamped[i, j] + if val >= 0: + # Positive distribution + m_pos += 1 + x_bin = math.floor((val - xmin_pos) / bin_width_pos) + x_bin = max(0, min(x_bin, n_bins - 1)) + f_pos[x_bin] += 1 + else: + # Negative distribution (use absolute value) + m_neg += 1 + abs_val = abs(val) + x_bin = math.floor((abs_val - xmin_neg) / bin_width_neg) + x_bin = max(0, min(x_bin, n_bins - 1)) + f_neg[x_bin] += 1 + + # Handle edge case where all values are one sign + if m_pos == 0: + m_pos = 1 # Avoid division by zero + if m_neg == 0: + m_neg = 1 # Avoid division by zero + + # Normalize to probabilities + f_pos = f_pos / m_pos + f_neg = f_neg / m_neg ### - # Step 2: Calculate null distributions across lengths + # Step 2: Calculate separate null distributions across lengths ### - scores = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64) - scores[1, :n_bins] = f + # Positive distributions + scores_pos = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64) + scores_pos[1, :n_bins] = f_pos + rcdfs_pos = np.zeros_like(scores_pos) + rcdfs_pos[:, 0] = 1.0 - rcdfs = np.zeros_like(scores) - rcdfs[:, 0] = 1.0 + # Negative distributions + scores_neg = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64) + scores_neg[1, :n_bins] = f_neg + rcdfs_neg = np.zeros_like(scores_neg) + rcdfs_neg[:, 0] = 1.0 + # Build convolutions for positive scores for seqlet_len in range(2, max_seqlet_len + 1): for i in range(n_bins * (seqlet_len - 1)): for j in range(n_bins): - scores[seqlet_len, i + j] += scores[seqlet_len - 1, i] * f[j] + scores_pos[seqlet_len, i + j] += scores_pos[seqlet_len - 1, i] * f_pos[j] for i in range(1, n_bins * seqlet_len): - 
rcdfs[seqlet_len, i] = max(rcdfs[seqlet_len, i - 1] - scores[seqlet_len, i], 0) + rcdfs_pos[seqlet_len, i] = max(rcdfs_pos[seqlet_len, i - 1] - scores_pos[seqlet_len, i], 0) + + # Build convolutions for negative scores + for seqlet_len in range(2, max_seqlet_len + 1): + for i in range(n_bins * (seqlet_len - 1)): + for j in range(n_bins): + scores_neg[seqlet_len, i + j] += scores_neg[seqlet_len - 1, i] * f_neg[j] + + for i in range(1, n_bins * seqlet_len): + rcdfs_neg[seqlet_len, i] = max(rcdfs_neg[seqlet_len, i - 1] - scores_neg[seqlet_len, i], 0) ### # Step 3: Calculate p-values given these 1-CDFs ### - X_csum = np.zeros((n, l + 1)) + X_csum = np.zeros((n, l + 1)) # Cumulative sum for each example for i in range(n): for j in range(l): X_csum[i, j + 1] = X_csum[i, j] + X[i, j] ### - # Step 4: Decode p-values into seqlets + # Step 4: Decode p-values into seqlets using appropriate distributions ### seqlets = [] for i in range(n): + # calculate p-values for every possible seqlet position and length p_value = np.ones((max_seqlet_len + 1, l), dtype=np.float64) p_value[:min_seqlet_len] = 0 p_value[:, -min_seqlet_len] = 1 for seqlet_len in range(min_seqlet_len, max_seqlet_len + 1): for k in range(l - seqlet_len + 1): - x_ = X_csum[i, k + seqlet_len] - X_csum[i, k] - x_ = math.floor((x_ - xmin * seqlet_len) / bin_width) + attr_sum = X_csum[i, k + seqlet_len] - X_csum[i, k] + + # Choose appropriate distribution based on attribution sum sign + if attr_sum >= 0: + # Use positive distribution + x_bin = math.floor((attr_sum - xmin_pos * seqlet_len) / bin_width_pos) + x_bin = max(0, min(x_bin, n_bins * seqlet_len - 1)) + p_val = rcdfs_pos[seqlet_len, x_bin] + else: + # Use negative distribution with absolute value + abs_attr_sum = abs(attr_sum) + x_bin = math.floor((abs_attr_sum - xmin_neg * seqlet_len) / bin_width_neg) + x_bin = max(0, min(x_bin, n_bins * seqlet_len - 1)) + p_val = rcdfs_neg[seqlet_len, x_bin] - p_value[seqlet_len, k] = max(rcdfs[seqlet_len, x_], 
p_value[seqlet_len - 1, k]) + # Assign highest of p-values of internal spans + p_value[seqlet_len, k] = max(p_val, p_value[seqlet_len - 1, k]) # Iteratively identify spans, from longest to shortest, that satisfy the # recursive p-value threshold. @@ -697,18 +787,22 @@ def _recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, a seqlet_len = max_seqlet_len - j while True: + # find the position with the lowest p-value for this length start = p_value[seqlet_len].argmin() p = p_value[seqlet_len, start] - p_value[seqlet_len, start] = 1 + p_value[seqlet_len, start] = 1 # avoid finding this again + # if p-value is above threshold, we're done with this length if p >= threshold: break + # check if all internal spans also satisfy the threshold for k in range(1, seqlet_len): if p_value[seqlet_len - k, start + k] >= threshold: - break + break # reject this position else: + # valid seqlet found, mark all overlapping positions as used for end in range(start, min(start + seqlet_len, l - 1)): p_value[:, end] = 1 @@ -718,3 +812,13 @@ def _recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, a seqlets.append((i, start, end, attr, p)) return seqlets + + +if __name__ == "__main__": + # test _recursive_seqlets + sample_oh = np.load("tests/data/sample_oh.npz")["oh"] + sample_contrib = np.load("tests/data/sample_contrib.npz")["contrib"] + + seqlets_df, seqlet_matrices = extract_seqlets(sample_contrib, sample_oh, threshold=0.05) + print(seqlets_df.head()) + print(len(seqlet_matrices)) diff --git a/tests/test_pp.py b/tests/test_pp.py index a438fca..2bb02fd 100644 --- a/tests/test_pp.py +++ b/tests/test_pp.py @@ -28,7 +28,26 @@ def test_extract_seqlets_real_data(self, sample_contrib_data, sample_oh_data): """Test extract_seqlets with real data.""" seqlet_df, seqlet_matrices = tm.pp.extract_seqlets(sample_contrib_data, sample_oh_data) - assert len(seqlet_df) == len(seqlet_matrices) == 227 + # Debug: Check positive vs negative seqlets + 
positive_seqlets = seqlet_df[seqlet_df["attribution"] > 0] + negative_seqlets = seqlet_df[seqlet_df["attribution"] < 0] + zero_seqlets = seqlet_df[seqlet_df["attribution"] == 0] + + print(f"Total seqlets found: {len(seqlet_df)}") + print(f"Positive seqlets: {len(positive_seqlets)}") + print(f"Negative seqlets: {len(negative_seqlets)}") + print(f"Zero attribution seqlets: {len(zero_seqlets)}") + + if len(negative_seqlets) > 0: + print( + f"Negative attribution range: [{negative_seqlets['attribution'].min():.3f}, {negative_seqlets['attribution'].max():.3f}]" + ) + if len(positive_seqlets) > 0: + print( + f"Positive attribution range: [{positive_seqlets['attribution'].min():.3f}, {positive_seqlets['attribution'].max():.3f}]" + ) + + assert len(seqlet_df) == len(seqlet_matrices) assert isinstance(seqlet_df, pd.DataFrame) assert isinstance(seqlet_matrices, list) @@ -40,6 +59,9 @@ def test_extract_seqlets_real_data(self, sample_contrib_data, sample_oh_data): for matrix in seqlet_matrices: assert np.all(matrix >= -1) and np.all(matrix <= 1) + # Verify we found some negative seqlets + assert len(negative_seqlets) > 0, "Should find some negative seqlets in real data" + class TestCalculateMotifSimilarity: """Test calculate_motif_similarity function."""