diff --git a/src/tfmindi/datasets.py b/src/tfmindi/datasets.py index 6a7815b..f539a04 100644 --- a/src/tfmindi/datasets.py +++ b/src/tfmindi/datasets.py @@ -361,10 +361,22 @@ def load_motif_annotations( def load_motif_to_dbd(motif_annotations: pd.DataFrame) -> dict[str, str]: """ - Create motif-to-DNA-binding-domain mapping for human TFs. + Create motif-to-DNA-binding-domain mapping via evidence-hierarchy voting. - Takes motif annotations and maps motifs to their DNA-binding domains - based on TF annotations and human TF database information. + For each motif, walks evidence tiers in order:: + + Direct_annot → Orthology_annot → Motif_similarity_annot + → Motif_similarity_and_Orthology_annot + + At each tier, parses the comma-separated TF list, looks up each TF's DBD in + the human TF database, and collapses to a set (family-collapsed count). If + the set has size 1, that DBD wins. If the set has size > 1 AND we are at + the Direct_annot tier, the motif is labelled ``"Composite"`` (genuine + multi-TF dimer). If a lower tier is ambiguous, we fall through to the next + tier. If no tier yields a decision, the motif is absent from the returned + dict. + + TFs missing from the human TF database are silently skipped. Parameters ---------- @@ -374,7 +386,8 @@ def load_motif_to_dbd(motif_annotations: pd.DataFrame) -> dict[str, str]: Returns ------- dict[str, str] - Dictionary mapping motif IDs to DNA-binding domain names + Dictionary mapping motif IDs to DNA-binding domain names (or + ``"Composite"`` for multi-family dimer motifs). 
Examples -------- @@ -384,33 +397,42 @@ def load_motif_to_dbd(motif_annotations: pd.DataFrame) -> dict[str, str]: >>> print(motif_to_dbd["hocomoco__FOXO1_HUMAN.H11MO.0.A"]) 'Forkhead' """ - motif_to_tf = motif_annotations.copy() - - # Flatten all TF annotations into individual TF names - motif_to_tf = ( - motif_to_tf.apply(lambda row: ", ".join(row.dropna()), axis=1) - .str.split(", ") - .explode() - .reset_index() - .rename({0: "TF"}, axis=1) - ) - # Download human TF annotations with DNA-binding domains human_tf_annot = pd.read_csv( "https://humantfs.ccbr.utoronto.ca/download/v_1.01/DatabaseExtract_v_1.01.csv", index_col=0, - )[["HGNC symbol", "DBD"]] - - motif_to_tf = motif_to_tf.merge(right=human_tf_annot, how="left", left_on="TF", right_on="HGNC symbol") - - # For each motif, take the most common (mode) DBD annotation - motif_to_dbd = ( - motif_to_tf.dropna() - .groupby("MotifID")["DBD"] - .agg(lambda x: x.mode().iat[0]) # take the first mode if there's a tie - .reset_index() - ) + )[["HGNC symbol", "DBD"]].dropna() + tf_to_dbd: dict[str, str] = dict(zip(human_tf_annot["HGNC symbol"], human_tf_annot["DBD"], strict=True)) + + tiers = [ + "Direct_annot", + "Orthology_annot", + "Motif_similarity_annot", + "Motif_similarity_and_Orthology_annot", + ] - motif_to_dbd = motif_to_dbd.set_index("MotifID")["DBD"].to_dict() + motif_to_dbd: dict[str, str] = {} + + for motif_id, row in motif_annotations.iterrows(): + label: str | None = None + for tier in tiers: + cell = row.get(tier) + if cell is None or (isinstance(cell, float) and pd.isna(cell)): + continue + tfs = [t.strip() for t in str(cell).split(",") if t.strip()] + dbds = {tf_to_dbd[tf] for tf in tfs if tf in tf_to_dbd} + if not dbds: + continue + if len(dbds) == 1: + label = next(iter(dbds)) + break + # len(dbds) > 1 + if tier == "Direct_annot": + label = "Composite" + break + # lower tier ambiguous → fall through + continue + if label is not None: + motif_to_dbd[motif_id] = label return motif_to_dbd diff --git 
a/src/tfmindi/pp/seqlets.py b/src/tfmindi/pp/seqlets.py index c8a42db..ca9ecaa 100644 --- a/src/tfmindi/pp/seqlets.py +++ b/src/tfmindi/pp/seqlets.py @@ -441,6 +441,21 @@ def create_seqlet_adata( motif_ppms_typed = [ppm.astype(dtype) for ppm in motif_ppms] var_df["motif_ppm"] = motif_ppms_typed + # var_df is indexed by motif header name; annotations and DBD are keyed on + # the .cb file name stem and broadcast to every motif sharing that file. + # The header name must be globally unique so that .loc[name, col] returns + # a scalar — otherwise seqlet_dbd gets silently corrupted with Series + # values. Enforce that invariant loudly before any writes. + if not var_df.index.is_unique: + dupes = var_df.index[var_df.index.duplicated()].unique().tolist() + raise ValueError( + f"create_seqlet_adata: var_df index (motif header name) is not unique. " + f"Duplicates: {dupes[:10]}{'...' if len(dupes) > 10 else ''}. " + f"Two motifs from different .cb files share the same header — " + f"downstream per-seqlet DBD lookup would return a Series instead " + f"of a scalar." + ) + # Store motif annotations in .var if provided if motif_annotations is not None and motif_names is not None: # Add annotations for motifs that are present in the similarity matrix diff --git a/src/tfmindi/tl/cluster.py b/src/tfmindi/tl/cluster.py index 90c1a89..746a30a 100644 --- a/src/tfmindi/tl/cluster.py +++ b/src/tfmindi/tl/cluster.py @@ -13,8 +13,53 @@ from tfmindi.backends import get_backend, is_gpu_available +def _vote_dbd( + data: np.ndarray, + cols: np.ndarray, + dbd_values: np.ndarray, + top_k: int, + min_share: float, +) -> object: + """Top-K weighted DBD vote for one seqlet. + + Returns NaN if the row is empty or has no positive scores, all top-K + matches are NaN/Composite, or the winner's share of the top-K vote falls + below ``min_share``. 
+ """ + if data.size == 0 or data.max() <= 0: + return np.nan + k = min(top_k, data.size) + top_local = np.argpartition(-data, k - 1)[:k] + top_cols = cols[top_local] + top_scores = data[top_local] + top_dbds = dbd_values[top_cols] + + totals: dict[str, float] = {} + for dbd, score in zip(top_dbds, top_scores, strict=False): + if dbd is None or (isinstance(dbd, float) and np.isnan(dbd)): + continue + if dbd == "Composite": + continue + totals[dbd] = totals.get(dbd, 0.0) + float(score) + + if not totals: + return np.nan + total_weight = sum(totals.values()) + if total_weight <= 0: + return np.nan + winner, winner_weight = max(totals.items(), key=lambda kv: kv[1]) + if winner_weight / total_weight < min_share: + return np.nan + return winner + + def cluster_seqlets( - adata: AnnData, resolution: float = 3.0, pca_svd_solver: str | None = None, *, recompute: bool = False + adata: AnnData, + resolution: float = 3.0, + pca_svd_solver: str | None = None, + *, + recompute: bool = False, + top_k_motifs: int = 5, + dbd_vote_min_share: float = 0.4, ) -> None: """ Perform complete clustering workflow including dimensionality reduction, clustering, and functional annotation. @@ -49,6 +94,11 @@ def cluster_seqlets( recompute If False (default), reuse existing PCA and neighborhood graph computations if available. If True, always recompute PCA, neighbors, and t-SNE from scratch. + top_k_motifs + Number of top TomTom matches used in the per-seqlet DBD vote (default: 5). + dbd_vote_min_share + Minimum fraction of the top-K weighted vote the winner must hold to be + accepted; otherwise the seqlet is labelled NaN (default: 0.4). 
Returns ------- @@ -149,19 +199,26 @@ def cluster_seqlets( adata.obs["mean_contrib"] = np.nan if "dbd" in adata.var.columns: - # find top motif for all seqlets at once - # For sparse matrices, argmax along axis=1 gives the column index of max value in each row from scipy import sparse + dbd_values = adata.var["dbd"].to_numpy(dtype=object) + seqlet_dbds: list[object] = [] + if sparse.issparse(adata.X): - # argmax on sparse matrix can return 2D array, ensure 1D - top_motif_indices = np.asarray(adata.X.argmax(axis=1)).flatten() + X = adata.X.tocsr() + for i in range(X.shape[0]): + start, end = X.indptr[i], X.indptr[i + 1] + data = X.data[start:end] + cols = X.indices[start:end] + seqlet_dbds.append(_vote_dbd(data, cols, dbd_values, top_k_motifs, dbd_vote_min_share)) else: - top_motif_indices = adata.X.argmax(axis=1) + X_dense = np.asarray(adata.X) + for i in range(X_dense.shape[0]): + row = X_dense[i] + nz = np.flatnonzero(row > 0) + seqlet_dbds.append(_vote_dbd(row[nz], nz, dbd_values, top_k_motifs, dbd_vote_min_share)) - top_motif_names = adata.var.index[top_motif_indices] - seqlet_dbds = [adata.var.loc[motif_name, "dbd"] for motif_name in top_motif_names] - adata.obs["seqlet_dbd"] = seqlet_dbds + adata.obs["seqlet_dbd"] = pd.Series(seqlet_dbds, index=adata.obs.index, dtype=object) else: print("Warning: No DBD annotations found in adata.var['dbd']") adata.obs["seqlet_dbd"] = np.nan @@ -189,8 +246,11 @@ def cluster_seqlets( # Test: One-tailed binomial test asking "Is k significantly greater than expected?" # This gives us a p-value for enrichment: P(X >= k | n, p) - # background probability. - dbd_to_probability = adata.var["dbd"].value_counts(normalize=True, dropna=False).to_dict() + # background probability. NaN is kept as legitimate background mass + # (unknown-DBD motifs still consume library slots); "Composite" is + # folded into NaN so it can never be selected as a cluster label. 
+ _dbd_series = adata.var["dbd"].replace({"Composite": np.nan}) + dbd_to_probability = _dbd_series.value_counts(normalize=True, dropna=False).to_dict() def get_dbd_min_pval(df: pd.Series) -> str: """ @@ -207,7 +267,9 @@ def get_dbd_min_pval(df: pd.Series) -> str: # k = n_success # N = number of draws # dbd_to_p = prob of sucess - p_value = binom.sf(k - 1, N, dbd_to_probability[dbd]) + # Fall back to p=1.0 (no enrichment) for labels absent from the + # background — can only happen if a stray label slips through. + p_value = binom.sf(k - 1, N, dbd_to_probability.get(dbd, 1.0)) if p_value < min_pval: min_pval = p_value best_dbd = dbd diff --git a/tests/test_seqlet_assignment.py b/tests/test_seqlet_assignment.py new file mode 100644 index 0000000..c5ca925 --- /dev/null +++ b/tests/test_seqlet_assignment.py @@ -0,0 +1,241 @@ +"""Regression tests for the seqlet DBD assignment bugfix. + +See thoughts/shared/plans/2026-04-09-bugfix-seqlet-assignment.md +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +import scipy.sparse as sp + +import tfmindi as tm + + +class TestLoadMotifToDbdHierarchy: + """Contract for the rewritten load_motif_to_dbd.""" + + @pytest.fixture + def fake_tf_to_dbd(self, monkeypatch): + """Patch the human TF CSV download with a synthetic in-memory table.""" + table = pd.DataFrame( + { + "HGNC symbol": [ + "GATA1", + "GATA2", + "GATA3", + "GATA4", + "GATA5", + "GATA6", + "NFIC", + "NFIA", + "E2F1", + "FOS", + "JUN", + "SOX2", + "POU5F1", + ], + "DBD": [ + "GATA", + "GATA", + "GATA", + "GATA", + "GATA", + "GATA", + "CTF/NF-I", + "CTF/NF-I", + "E2F", + "bZIP", + "bZIP", + "HMG/Sox", + "Homeodomain", + ], + } + ) + table.index = range(1, len(table) + 1) + + def fake_read_csv(url, index_col=0): + assert "humantfs" in url + return table + + monkeypatch.setattr("tfmindi.datasets.pd.read_csv", fake_read_csv) + return table + + def test_direct_single_family_wins(self, fake_tf_to_dbd): + annots = pd.DataFrame( + { 
+ "Direct_annot": ["GATA1"], + "Motif_similarity_annot": [None], + "Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_clean_gata"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert out["motif_clean_gata"] == "GATA" + + def test_direct_multi_family_is_composite(self, fake_tf_to_dbd): + """The tfdimers__MD00378 failure case.""" + annots = pd.DataFrame( + { + "Direct_annot": ["NFIC, NFIA, E2F1"], + "Motif_similarity_annot": [None], + "Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["tfdimers__MD00378"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert out["tfdimers__MD00378"] == "Composite" + + def test_similarity_family_collapse(self, fake_tf_to_dbd): + """GATA1..GATA6 listed only under similarity should still collapse to GATA.""" + annots = pd.DataFrame( + { + "Direct_annot": [None], + "Motif_similarity_annot": ["GATA1, GATA2, GATA3, GATA4, GATA5, GATA6"], + "Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_gata_family"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert out["motif_gata_family"] == "GATA" + + def test_orthology_beats_similarity(self, fake_tf_to_dbd): + """Orthology tier is consulted before similarity tier.""" + annots = pd.DataFrame( + { + "Direct_annot": [None], + "Motif_similarity_annot": ["FOS, JUN"], # bZIP + "Orthology_annot": ["GATA2"], # GATA + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_ortho_gata"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert out["motif_ortho_gata"] == "GATA" + + def test_lower_tier_ambiguous_falls_through(self, fake_tf_to_dbd): + """Similarity listing TFs from many families and no other evidence → NaN.""" + annots = pd.DataFrame( + { + "Direct_annot": [None], + "Motif_similarity_annot": ["GATA1, FOS, SOX2, POU5F1"], + 
"Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_noisy"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert "motif_noisy" not in out or pd.isna(out.get("motif_noisy")) + + def test_no_evidence_means_absent(self, fake_tf_to_dbd): + annots = pd.DataFrame( + { + "Direct_annot": [None], + "Motif_similarity_annot": [None], + "Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_empty"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert "motif_empty" not in out or pd.isna(out.get("motif_empty")) + + def test_missing_tf_from_human_table_is_ignored(self, fake_tf_to_dbd): + """TFs not in the human TF CSV must not raise KeyError.""" + annots = pd.DataFrame( + { + "Direct_annot": ["FOO_NOT_HUMAN, GATA1"], + "Motif_similarity_annot": [None], + "Orthology_annot": [None], + "Motif_similarity_and_Orthology_annot": [None], + }, + index=["motif_partial"], + ) + annots.index.name = "MotifID" + out = tm.load_motif_to_dbd(annots) + assert out["motif_partial"] == "GATA" + + +class TestSeqletDbdTopKVote: + """Unit-level contract for the rewritten per-seqlet block in cluster_seqlets.""" + + def _build_minimal_adata(self, similarity_rows, var_dbds, n_filler: int = 30): + """Build a minimal AnnData that cluster_seqlets' full pipeline can consume. + + The provided ``similarity_rows`` become the leading rows of ``X``; the + remaining ``n_filler`` rows are filled with reproducible random noise so + that scanpy PCA/neighbors/tSNE/Leiden do not degenerate on a trivially + rank-deficient matrix. Only the leading rows should be asserted on. 
+ """ + import anndata as ad + + rng = np.random.default_rng(0) + test_rows = np.asarray(similarity_rows, dtype=np.float32) + n_test, n_vars = test_rows.shape + filler = rng.uniform(0.1, 1.0, size=(n_filler, n_vars)).astype(np.float32) + full = np.vstack([test_rows, filler]) + X = sp.csr_array(full) + n_obs = X.shape[0] + obs = pd.DataFrame( + { + "seqlet_matrix": [np.zeros((4, 6), dtype=np.float32)] * n_obs, + }, + index=[str(i) for i in range(n_obs)], + ) + var = pd.DataFrame({"dbd": var_dbds}, index=[f"motif_{i}" for i in range(X.shape[1])]) + return ad.AnnData(X=X, obs=obs, var=var) + + def test_composite_top1_is_dropped_from_vote(self): + """Seqlet whose rank-1 is Composite but ranks 2-5 are bZIP → labelled bZIP.""" + row = [5.0, 4.0, 3.9, 3.8, 3.7, 0.0] # 6 motifs + var_dbds = ["Composite", "bZIP", "bZIP", "bZIP", "bZIP", "GATA"] + adata = self._build_minimal_adata([row], var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + assert adata.obs["seqlet_dbd"].iloc[0] == "bZIP" + + def test_nan_motifs_are_dropped_from_vote(self): + row = [10.0, 4.0, 3.9, 3.8, 3.7, 0.0] + var_dbds = [np.nan, "bZIP", "bZIP", "bZIP", "bZIP", "GATA"] + adata = self._build_minimal_adata([row], var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + assert adata.obs["seqlet_dbd"].iloc[0] == "bZIP" + + def test_empty_row_returns_nan(self): + row = [0.0] * 6 + var_dbds = ["GATA", "bZIP", "Ets", "bHLH", "Forkhead", "NR"] + adata = self._build_minimal_adata([row], var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + assert pd.isna(adata.obs["seqlet_dbd"].iloc[0]) + + def test_rejection_threshold_triggers_on_ties(self): + row = [1.0, 1.0, 1.0, 1.0, 1.0, 0.0] + var_dbds = ["GATA", "bZIP", "Ets", "bHLH", "Forkhead", "NR"] + adata = self._build_minimal_adata([row], var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + 
# Each family has 0.2 share → winner_share < 0.4 → NaN + assert pd.isna(adata.obs["seqlet_dbd"].iloc[0]) + + def test_clear_winner_above_threshold(self): + row = [5.0, 4.5, 1.0, 1.0, 1.0, 0.0] + var_dbds = ["bZIP", "bZIP", "GATA", "Ets", "Forkhead", "NR"] + adata = self._build_minimal_adata([row], var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + assert adata.obs["seqlet_dbd"].iloc[0] == "bZIP" + + def test_composite_excluded_from_cluster_background(self): + """'Composite' must not appear in cluster_dbd.""" + row = [5.0, 4.0, 0.0, 0.0, 0.0, 0.0] + var_dbds = ["Composite", "bZIP", "GATA", "Ets", "Forkhead", "NR"] + adata = self._build_minimal_adata([row] * 10, var_dbds) + tm.tl.cluster_seqlets(adata, resolution=1.0, top_k_motifs=5, dbd_vote_min_share=0.4) + assert "Composite" not in set(adata.obs["cluster_dbd"].dropna().unique()) diff --git a/tests/test_tl.py b/tests/test_tl.py index f2df4f3..bff1f73 100644 --- a/tests/test_tl.py +++ b/tests/test_tl.py @@ -60,13 +60,11 @@ def test_cluster_seqlets_output_structure(self, sample_clustered_adata): # Test mean_contrib calculation assert adata.obs["mean_contrib"].min() >= 0, "Mean contrib should be non-negative" - # Test seqlet_dbd assignment (should match top motif for each seqlet) - for i in range(min(5, adata.n_obs)): # Check first 5 seqlets - top_motif_idx = adata.X[i].argmax() - top_motif_name = adata.var.index[top_motif_idx] - expected_dbd = adata.var.loc[top_motif_name, "dbd"] - actual_dbd = adata.obs.iloc[i]["seqlet_dbd"] - assert actual_dbd == expected_dbd, f"DBD mismatch for seqlet {i}" + # Test seqlet_dbd assignment: every non-NaN label must come from the + # DBD vocabulary in var (excluding the synthetic Composite sentinel). 
+ valid_dbds = set(adata.var["dbd"].dropna().unique()) - {"Composite"} + for seqlet_dbd in adata.obs["seqlet_dbd"].dropna().unique(): + assert seqlet_dbd in valid_dbds, f"Unknown seqlet_dbd: {seqlet_dbd}" # Test cluster_dbd consistency (all seqlets in same cluster should have same cluster_dbd) for cluster in adata.obs["leiden"].unique():