Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 133 additions & 29 deletions src/tfmindi/pp/seqlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,15 @@ def recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, ad
THIS FUNCTION IS A DIRECT COPY FROM THE TANGERMEME REPOSITORY FROM JACOB SCHREIBER.
We do a direct copy here since we only need this function and we want to avoid the heavy torch installation.

NOTE: Currently only *positive* seqlets will be identified. The easiest way
to get negative seqlets is to run this on the absolute value of the
attribution values and then re-extract the attribution sums given the
boundaries.

This algorithm identifies spans of high attribution characters, called
seqlets, using a simple approach derived from the Tomtom/FIMO algorithms.
First, distributions of attribution sums are created for all potential
seqlet lengths by discretizing the sum, with one set of distributions for
positive attribution values and one for negative attribution values. Then,
seqlet lengths by discretizing the attribution sum into integers. Then,
CDFs are calculated for each distribution (or, more specifically, 1-CDFs).
Finally, p-values are calculated via lookup to these 1-CDFs for all
potential seqlets, yielding a (n_positions, n_lengths) matrix of p-values.
Expand Down Expand Up @@ -618,97 +622,187 @@ def recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, ad

@numba.njit
def _recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, additional_flanks=0, n_bins=1000):
"""Call seqlets recursively using the Tangermeme algorithm.
"""
Call seqlets recursively using the Tangermeme algorithm with separate null distributions for positive and negative scores.

This algorithm has four steps.

(1) Convert attribution scores into integer bins and calculate a histogram
(2) Convert these histograms into null distributions across lengths
(3) Use the null distributions to calculate p-values for each possible length
(1) Convert attribution scores into integer bins and calculate separate histograms for positive and negative scores
(2) Convert these histograms into separate null distributions across lengths
(3) Use the appropriate null distribution to calculate p-values for each possible length based on sign
(4) Decode this matrix of p-values to find the longest seqlets
"""
n, l = X.shape
m = n * l

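# Clamp attribution values so that extreme outliers do not dominate the binning.
# NOTE: the bounds below are min + 5% / 95% of the (max - min) value range, not data percentiles.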
xmax_orig, xmin_orig = X.max(), X.min()
clamp_max = xmin_orig + 0.95 * (xmax_orig - xmin_orig) # 95th percentile
clamp_min = xmin_orig + 0.05 * (xmax_orig - xmin_orig) # 5th percentile
X_clamped = np.clip(X, clamp_min, clamp_max)

###
# Step 1: Calculate a histogram of binned scores
# Step 1: Calculate separate histograms for positive and negative scores
###

xmax, xmin = X.max(), X.min()
bin_width = (xmax - xmin) / (n_bins - 1)
# Find min/max for clamped dataset
xmax, xmin = X_clamped.max(), X_clamped.min()

# Calculate separate bin ranges for positive and negative scores
# For positive: use [0, xmax] if xmax > 0, else use [xmin, 0]
if xmax > 0:
xmax_pos = xmax
xmin_pos = max(0.0, xmin) # Start from 0 or higher
else:
xmax_pos = 0.0
xmin_pos = xmin

# For negative: use absolute values [0, |xmin|] if xmin < 0, else use [0, xmax]
if xmin < 0:
xmax_neg = abs(xmin)
xmin_neg = 0.0
else:
xmax_neg = xmax
xmin_neg = 0.0

# Prevent division by zero
if xmax_pos == xmin_pos:
xmax_pos = xmin_pos + 1e-6
if xmax_neg == xmin_neg:
xmax_neg = xmin_neg + 1e-6

f = np.zeros(n_bins, dtype=np.float64)
bin_width_pos = (xmax_pos - xmin_pos) / (n_bins - 1)
bin_width_neg = (xmax_neg - xmin_neg) / (n_bins - 1)

# Build distributions and count in single pass through data
f_pos = np.zeros(n_bins, dtype=np.float64)
f_neg = np.zeros(n_bins, dtype=np.float64)
m_pos = 0
m_neg = 0

for i in range(n):
for j in range(l):
x_bin = math.floor((X[i, j] - xmin) / bin_width)
f[x_bin] += 1

f = f / m
val = X_clamped[i, j]
if val >= 0:
# Positive distribution
m_pos += 1
x_bin = math.floor((val - xmin_pos) / bin_width_pos)
x_bin = max(0, min(x_bin, n_bins - 1))
f_pos[x_bin] += 1
else:
# Negative distribution (use absolute value)
m_neg += 1
abs_val = abs(val)
x_bin = math.floor((abs_val - xmin_neg) / bin_width_neg)
x_bin = max(0, min(x_bin, n_bins - 1))
f_neg[x_bin] += 1

# Handle edge case where all values are one sign
if m_pos == 0:
m_pos = 1 # Avoid division by zero
if m_neg == 0:
m_neg = 1 # Avoid division by zero

# Normalize to probabilities
f_pos = f_pos / m_pos
f_neg = f_neg / m_neg

###
# Step 2: Calculate null distributions across lengths
# Step 2: Calculate separate null distributions across lengths
###

scores = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64)
scores[1, :n_bins] = f
# Positive distributions
scores_pos = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64)
scores_pos[1, :n_bins] = f_pos
rcdfs_pos = np.zeros_like(scores_pos)
rcdfs_pos[:, 0] = 1.0

rcdfs = np.zeros_like(scores)
rcdfs[:, 0] = 1.0
# Negative distributions
scores_neg = np.zeros((max_seqlet_len + 1, n_bins * max_seqlet_len), dtype=np.float64)
scores_neg[1, :n_bins] = f_neg
rcdfs_neg = np.zeros_like(scores_neg)
rcdfs_neg[:, 0] = 1.0

# Build convolutions for positive scores
for seqlet_len in range(2, max_seqlet_len + 1):
for i in range(n_bins * (seqlet_len - 1)):
for j in range(n_bins):
scores[seqlet_len, i + j] += scores[seqlet_len - 1, i] * f[j]
scores_pos[seqlet_len, i + j] += scores_pos[seqlet_len - 1, i] * f_pos[j]

for i in range(1, n_bins * seqlet_len):
rcdfs[seqlet_len, i] = max(rcdfs[seqlet_len, i - 1] - scores[seqlet_len, i], 0)
rcdfs_pos[seqlet_len, i] = max(rcdfs_pos[seqlet_len, i - 1] - scores_pos[seqlet_len, i], 0)

# Build convolutions for negative scores
for seqlet_len in range(2, max_seqlet_len + 1):
for i in range(n_bins * (seqlet_len - 1)):
for j in range(n_bins):
scores_neg[seqlet_len, i + j] += scores_neg[seqlet_len - 1, i] * f_neg[j]

for i in range(1, n_bins * seqlet_len):
rcdfs_neg[seqlet_len, i] = max(rcdfs_neg[seqlet_len, i - 1] - scores_neg[seqlet_len, i], 0)

###
# Step 3: Calculate p-values given these 1-CDFs
###

X_csum = np.zeros((n, l + 1))
X_csum = np.zeros((n, l + 1)) # Cumulative sum for each example
for i in range(n):
for j in range(l):
X_csum[i, j + 1] = X_csum[i, j] + X[i, j]

###
# Step 4: Decode p-values into seqlets
# Step 4: Decode p-values into seqlets using appropriate distributions
###

seqlets = []

for i in range(n):
# calculate p-values for every possible seqlet position and length
p_value = np.ones((max_seqlet_len + 1, l), dtype=np.float64)
p_value[:min_seqlet_len] = 0
p_value[:, -min_seqlet_len] = 1

for seqlet_len in range(min_seqlet_len, max_seqlet_len + 1):
for k in range(l - seqlet_len + 1):
x_ = X_csum[i, k + seqlet_len] - X_csum[i, k]
x_ = math.floor((x_ - xmin * seqlet_len) / bin_width)
attr_sum = X_csum[i, k + seqlet_len] - X_csum[i, k]

# Choose appropriate distribution based on attribution sum sign
if attr_sum >= 0:
# Use positive distribution
x_bin = math.floor((attr_sum - xmin_pos * seqlet_len) / bin_width_pos)
x_bin = max(0, min(x_bin, n_bins * seqlet_len - 1))
p_val = rcdfs_pos[seqlet_len, x_bin]
else:
# Use negative distribution with absolute value
abs_attr_sum = abs(attr_sum)
x_bin = math.floor((abs_attr_sum - xmin_neg * seqlet_len) / bin_width_neg)
x_bin = max(0, min(x_bin, n_bins * seqlet_len - 1))
p_val = rcdfs_neg[seqlet_len, x_bin]

p_value[seqlet_len, k] = max(rcdfs[seqlet_len, x_], p_value[seqlet_len - 1, k])
# Assign highest of p-values of internal spans
p_value[seqlet_len, k] = max(p_val, p_value[seqlet_len - 1, k])

# Iteratively identify spans, from longest to shortest, that satisfy the
# recursive p-value threshold.
for j in range(max_seqlet_len - min_seqlet_len + 1):
seqlet_len = max_seqlet_len - j

while True:
# find the position with the lowest p-value for this length
start = p_value[seqlet_len].argmin()
p = p_value[seqlet_len, start]
p_value[seqlet_len, start] = 1
p_value[seqlet_len, start] = 1 # avoid finding this again

# if p-value is above threshold, we're done with this length
if p >= threshold:
break

# check if all internal spans also satisfy the threshold
for k in range(1, seqlet_len):
if p_value[seqlet_len - k, start + k] >= threshold:
break
break # reject this position

else:
# valid seqlet found, mark all overlapping positions as used
for end in range(start, min(start + seqlet_len, l - 1)):
p_value[:, end] = 1

Expand All @@ -718,3 +812,13 @@ def _recursive_seqlets(X, threshold=0.01, min_seqlet_len=4, max_seqlet_len=25, a
seqlets.append((i, start, end, attr, p))

return seqlets


if __name__ == "__main__":
    # Ad-hoc smoke run of the seqlet extraction pipeline on the bundled test
    # arrays. Paths are relative, so this is expected to be launched from the
    # repository root.
    one_hot = np.load("tests/data/sample_oh.npz")["oh"]
    contributions = np.load("tests/data/sample_contrib.npz")["contrib"]

    # extract_seqlets returns the seqlet table and the per-seqlet matrices.
    result = extract_seqlets(contributions, one_hot, threshold=0.05)
    seqlet_table, per_seqlet_matrices = result

    # Show the first rows of the table and how many seqlets were extracted.
    print(seqlet_table.head())
    print(len(per_seqlet_matrices))
24 changes: 23 additions & 1 deletion tests/test_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,26 @@ def test_extract_seqlets_real_data(self, sample_contrib_data, sample_oh_data):
"""Test extract_seqlets with real data."""
seqlet_df, seqlet_matrices = tm.pp.extract_seqlets(sample_contrib_data, sample_oh_data)

assert len(seqlet_df) == len(seqlet_matrices) == 227
# Debug: Check positive vs negative seqlets
positive_seqlets = seqlet_df[seqlet_df["attribution"] > 0]
negative_seqlets = seqlet_df[seqlet_df["attribution"] < 0]
zero_seqlets = seqlet_df[seqlet_df["attribution"] == 0]

print(f"Total seqlets found: {len(seqlet_df)}")
print(f"Positive seqlets: {len(positive_seqlets)}")
print(f"Negative seqlets: {len(negative_seqlets)}")
print(f"Zero attribution seqlets: {len(zero_seqlets)}")

if len(negative_seqlets) > 0:
print(
f"Negative attribution range: [{negative_seqlets['attribution'].min():.3f}, {negative_seqlets['attribution'].max():.3f}]"
)
if len(positive_seqlets) > 0:
print(
f"Positive attribution range: [{positive_seqlets['attribution'].min():.3f}, {positive_seqlets['attribution'].max():.3f}]"
)

assert len(seqlet_df) == len(seqlet_matrices)

assert isinstance(seqlet_df, pd.DataFrame)
assert isinstance(seqlet_matrices, list)
Expand All @@ -40,6 +59,9 @@ def test_extract_seqlets_real_data(self, sample_contrib_data, sample_oh_data):
for matrix in seqlet_matrices:
assert np.all(matrix >= -1) and np.all(matrix <= 1)

# Verify we found some negative seqlets
assert len(negative_seqlets) > 0, "Should find some negative seqlets in real data"


class TestCalculateMotifSimilarity:
"""Test calculate_motif_similarity function."""
Expand Down
Loading