Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class FeatureTooltip {
}
}

export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode }) {
export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode, hiddenCategories }) {
const containerRef = useRef(null)
const viewRef = useRef(null)
const onFeatureClickRef = useRef(onFeatureClick)
Expand Down Expand Up @@ -267,7 +267,8 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
if (!viewRef.current) return

let categoryColName = null
let colors = Array(50).fill(DEFAULT_COLOR)
const HIDDEN_COLOR = darkMode ? "#0a0a0a" : "#fafafa"
let colors = Array(50).fill(HIDDEN_COLOR)

if (categoryColumn && categoryColumn !== "none") {
const colInfo = categoryColumns?.find(c => c.name === categoryColumn)
Expand All @@ -278,6 +279,17 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
} else if (colInfo.type === 'string') {
categoryColName = `${categoryColumn}_cat`
colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10))
// Map colors to match DENSE_RANK order, dim non-selected when filtering
if (hiddenCategories && hiddenCategories.size > 0 && features) {
const allCatNames = [...new Set(
features.map(f => f[categoryColumn]).filter(v => v != null)
)].sort()
colors = colors.map((c, i) => {
const name = allCatNames[i]
if (!name) return c
return !hiddenCategories.has(name) ? HIDDEN_COLOR : c
})
}
} else {
categoryColName = categoryColumn
colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10))
Expand All @@ -291,7 +303,7 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
selection: null,
tooltip: null,
})
}, [categoryColumn, categoryColumns])
}, [categoryColumn, categoryColumns, hiddenCategories])

// Handle resize
useEffect(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,25 @@ const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, fo
lines.push('')
}

// GSEA enrichment section
const gseaCsvFields = [
{ key: 'gsea_overall_best', label: 'GSEA Overall Best' },
{ key: 'gsea_GO_Biological_Process', label: 'GSEA GO Biological Process' },
{ key: 'gsea_GO_Molecular_Function', label: 'GSEA GO Molecular Function' },
{ key: 'gsea_GO_Cellular_Component', label: 'GSEA GO Cellular Component' },
{ key: 'gsea_InterPro_Domains', label: 'GSEA InterPro Domains' },
{ key: 'gsea_Pfam_Domains', label: 'GSEA Pfam Domains' },
{ key: 'gsea_GO_Slim', label: 'GSEA GO Slim' },
]
const gseaLines = gseaCsvFields
.filter(({ key }) => feature[key] && feature[key] !== 'unlabeled')
.map(({ key, label }) => `${label},${feature[key]}`)
if (gseaLines.length > 0) {
lines.push('=== GSEA ENRICHMENT ===')
gseaLines.forEach(l => lines.push(l))
lines.push('')
}

// Examples section
if (examples && examples.length > 0) {
lines.push('=== ACTIVATION EXAMPLES ===')
Expand Down Expand Up @@ -657,6 +676,22 @@ const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, fo
if (ann.cpg) tags.push({ label: `CpG enriched`, color: '#fce4ec' })
if (ann.position) tags.push({ label: `N-terminal`, color: '#e8f5e9' })

// GSEA enrichment tags
const gseaFields = [
{ key: 'gsea_GO_Biological_Process', prefix: 'GO:BP', color: '#e8eaf6' },
{ key: 'gsea_GO_Molecular_Function', prefix: 'GO:MF', color: '#ede7f6' },
{ key: 'gsea_GO_Cellular_Component', prefix: 'GO:CC', color: '#e0f2f1' },
{ key: 'gsea_InterPro_Domains', prefix: 'InterPro', color: '#fff8e1' },
{ key: 'gsea_Pfam_Domains', prefix: 'Pfam', color: '#fbe9e7' },
{ key: 'gsea_GO_Slim', prefix: 'GO Slim', color: '#f1f8e9' },
]
for (const { key, prefix, color } of gseaFields) {
const val = feature[key]
if (val && val !== 'unlabeled' && val !== 'other') {
tags.push({ label: `${prefix}: ${val}`, color })
}
}

if (tags.length === 0) return null
return (
<div style={{ display: 'flex', flexWrap: 'wrap', gap: '4px', marginBottom: '10px' }}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,54 @@ export default function FeatureDetailPage({ feature, examples, vocabLogits, feat
<VocabLogitChart logits={logits} />
</div>

{/* Gene-Level GSEA Enrichment */}
{(() => {
const gseaFields = [
{ key: 'gsea_GO_Biological_Process', label: 'GO Biological Process' },
{ key: 'gsea_GO_Molecular_Function', label: 'GO Molecular Function' },
{ key: 'gsea_GO_Cellular_Component', label: 'GO Cellular Component' },
{ key: 'gsea_InterPro_Domains', label: 'InterPro Domains' },
{ key: 'gsea_Pfam_Domains', label: 'Pfam Domains' },
{ key: 'gsea_GO_Slim', label: 'GO Slim' },
]
const gseaEntries = gseaFields
.map(({ key, label }) => ({ label, value: feature[key] }))
.filter(e => e.value && e.value !== 'unlabeled' && e.value !== 'other')
const overallBest = feature.gsea_overall_best
if (gseaEntries.length === 0 && (!overallBest || overallBest === 'unlabeled')) return null
return (
<div style={styles.section}>
<div style={styles.sectionTitle}>Gene-Level Enrichment (GSEA)</div>
<div style={styles.sectionSubtitle}>
Genes ranked by activation strength, tested against GO, InterPro, and Pfam databases.
</div>
{overallBest && overallBest !== 'unlabeled' && (
<div style={{
padding: '8px 12px', marginBottom: '8px', borderRadius: '6px',
background: 'var(--bg-card-expanded)', border: '1px solid var(--accent)',
fontSize: '13px', fontWeight: '600', color: 'var(--text-heading)',
}}>
Best: {overallBest}
</div>
)}
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: '6px' }}>
{gseaEntries.map(({ label, value }) => (
<div key={label} style={{
padding: '6px 10px', borderRadius: '4px',
background: 'var(--bg-card)', border: '1px solid var(--border-card)',
fontSize: '11px',
}}>
<div style={{ color: 'var(--text-muted)', fontSize: '9px', fontWeight: '600', marginBottom: '2px' }}>
{label}
</div>
<div style={{ color: 'var(--text-primary)' }}>{value}</div>
</div>
))}
</div>
</div>
)
})()}

{/* Codon Annotations */}
<div style={styles.section}>
<div style={styles.sectionTitle}>Codon-Level Annotations</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None: # no
cmd.append("--normalize-input")
if t.get("max_grad_norm"):
cmd.extend(["--max-grad-norm", str(t.max_grad_norm)])
if t.get("lr_schedule", "constant") != "constant":
cmd.extend(["--lr-schedule", str(t.lr_schedule)])
if t.get("lr_min", 0.0) != 0.0:
cmd.extend(["--lr-min", str(t.lr_min)])
if t.get("lr_decay_steps"):
cmd.extend(["--lr-decay-steps", str(t.lr_decay_steps)])
if t.get("warmup_steps", 0) > 0:
cmd.extend(["--warmup-steps", str(t.warmup_steps)])

if t.wandb_enabled:
cmd.append("--wandb")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ train:
wandb_enabled: false
wandb_project: sae_codonfm_recipe
max_grad_norm: null
lr_schedule: constant
lr_min: 0.0
lr_decay_steps: null
warmup_steps: 0

# Eval
eval:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ def parse_args(): # noqa: D103
p.add_argument(
"--auto-interp-workers", type=int, default=1, help="Number of parallel workers for LLM calls (default: 1)"
)
p.add_argument(
"--gsea-report",
type=str,
default=None,
help="Path to gene_enrichment_report.json — adds GSEA context to auto-interp prompts",
)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--device", type=str, default=None)
return p.parse_args()
Expand Down Expand Up @@ -569,6 +575,7 @@ def run_auto_interp(
llm_provider="anthropic",
llm_model=None,
num_workers=1,
gsea_context=None,
):
"""Run LLM auto-interpretation using precomputed top-K indices.

Expand Down Expand Up @@ -712,33 +719,56 @@ def interpret_feature(f):

examples_str = "\n".join(f" Seq {i + 1}: {ex}" for i, ex in enumerate(feature_examples.get(f, [])))

prompt = f"""This is a feature from a sparse autoencoder trained on a DNA codon language model (CodonFM).
Each token is a codon (3 nucleotides) that encodes an amino acid.
# Build GSEA enrichment context if available
gsea_str = ""
if gsea_context and f in gsea_context:
gsea_info = gsea_context[f]
gsea_lines = []
for db, entry in gsea_info.items():
if entry:
gsea_lines.append(f" {db}: {entry['term_name']} (FDR={entry['fdr']:.4f})")
if gsea_lines:
gsea_str = "\n\nGene-level GSEA enrichment (genes ranked by activation, tested against annotation databases):\n"
gsea_str += "\n".join(gsea_lines)

prompt = f"""Analyze this sparse autoencoder feature from a DNA codon language model (CodonFM) to determine what predicts its activation pattern. Each token is a codon (3 nucleotides encoding one amino acid).

Top promoted codons (decoder logits): {pos_str}
Top suppressed codons: {neg_str}

Top activating sequences (***highlighted*** = high activation):
Each sequence may include metadata in brackets: gene name, data source (ClinVar=germline variants, COSMIC=somatic cancer mutations), pathogenicity label, PhyloP conservation score, variant info (ref>alt codon at position), and model effect score (more negative = higher predicted impact).
{examples_str}
Top activating sequences (***highlighted*** = high activation codons):
Metadata in brackets may include: gene name, data source (ClinVar/COSMIC), pathogenicity, PhyloP conservation, variant info (ref>alt codon at position), model effect score.
{examples_str}{gsea_str}

In 1 short sentence starting with "Fires on", describe what biological pattern this feature detects.
Consider: amino acid identity, specific codon choice, codon usage bias, positional context, CpG sites, wobble position patterns, and any variant/clinical metadata patterns you observe.
Analyze what predicts high vs low activation for this feature. This description should be concise but sufficient to predict activation levels on unseen codon sequences. The feature could be specific to a gene family, a codon usage pattern, a sequence motif, a functional role, a structural domain, etc.

Focus on:
- Which codons and amino acids are associated with high vs low activation, and whether specific synonymous codon choices matter
- Where in the gene sequence activation occurs (N-terminal, C-terminal, or throughout)
- What gene-level functional annotations (from GSEA enrichment if provided) characterize the top-activating genes
- Whether codon usage bias, CpG content, wobble position patterns, or GC content are relevant
- Any variant/clinical metadata patterns (pathogenicity, conservation, mutation impact)

Your description will be used to predict activation on held-out sequences, so only highlight factors relevant for prediction.

Format your response as:
Label: <one short phrase>
Description: <2-3 sentences starting with "The activation patterns are characterized by:">
Label: <one concise phrase summarizing what this feature detects>
Confidence: <0.00 to 1.00>"""

try:
response = client.generate(prompt)
text = response.text.strip()

label = None
description = None
confidence = 0.0

for line in text.split("\n"):
if line.startswith("Label:"):
label = line.replace("Label:", "").strip()
elif line.startswith("Description:"):
description = line.replace("Description:", "").strip()
elif line.startswith("Confidence:"):
try:
confidence = float(line.replace("Confidence:", "").strip())
Expand All @@ -749,21 +779,24 @@ def interpret_feature(f):
if not label:
label = f"Feature {f}"

return f, label, confidence
return f, label, confidence, description
except Exception as e:
print(f" Warning: auto-interp failed for feature {f}: {e}")
return f, f"Feature {f}", 0.0
return f, f"Feature {f}", 0.0, None

interpretations = {}
confidences = {}
descriptions = {}
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(interpret_feature, f): f for f in feature_indices}
for future in tqdm(as_completed(futures), total=len(feature_indices), desc=" Auto-interp"):
f, label, confidence = future.result()
f, label, confidence, description = future.result()
interpretations[f] = label
confidences[f] = confidence
if description:
descriptions[f] = description

return interpretations, confidences
return interpretations, confidences, descriptions


# ── Build summary labels ─────────────────────────────────────────────
Expand Down Expand Up @@ -927,6 +960,29 @@ def main(): # noqa: D103
auto_interp_labels[k_int] = {"label": v, "confidence": 0.0}
print(f" Loaded {len(auto_interp_labels)} existing interpretations")

# Load GSEA context if provided
gsea_context = None
if args.gsea_report:
gsea_report_path = Path(args.gsea_report)
if gsea_report_path.exists():
print(f" Loading GSEA report from {gsea_report_path}...")
with open(gsea_report_path) as f:
gsea_data = json.load(f)
gsea_context = {}
for fl in gsea_data.get("per_feature", []):
feat_idx = fl["feature_idx"]
per_db = {}
for db, entry in fl.get("best_per_database", {}).items():
if entry is not None:
per_db[db] = entry
if fl.get("overall_best"):
per_db["overall_best"] = fl["overall_best"]
if per_db:
gsea_context[feat_idx] = per_db
print(f" GSEA context loaded for {len(gsea_context)} features")
else:
print(f" WARNING: GSEA report not found at {gsea_report_path}")

if args.auto_interp:
print("\n[3/3] Auto-interpretation (LLM)...")
alive_features = [f for f in range(n_features) if f in codon_annotations]
Expand All @@ -942,7 +998,7 @@ def main(): # noqa: D103

if todo_features:
print(f" Running auto-interp on {len(todo_features)} features ({len(auto_interp_labels)} already done)")
new_labels, new_confidences = run_auto_interp(
new_labels, new_confidences, new_descriptions = run_auto_interp(
sae,
vocab_logits,
inference,
Expand All @@ -957,11 +1013,13 @@ def main(): # noqa: D103
llm_provider=args.llm_provider,
llm_model=args.llm_model,
num_workers=args.auto_interp_workers,
gsea_context=gsea_context,
)
for f in new_labels:
auto_interp_labels[f] = {
"label": new_labels[f],
"confidence": new_confidences[f],
"description": new_descriptions.get(f),
}
with open(auto_interp_ckpt, "w") as f:
json.dump(auto_interp_labels, f, indent=2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from codonfm_sae.eval.gene_enrichment import ( # noqa: E402
ANNOTATION_DATABASES,
GeneEnrichmentReport,
detect_gene_families,
download_obo_files,
rollup_go_slim,
run_gene_enrichment,
Expand Down Expand Up @@ -499,6 +500,18 @@ def main():
gsea_time = time.time() - t0
print(f"\n GSEA completed in {gsea_time:.1f}s")

# 6b. Detect gene families
print(" Detecting gene families...")
gene_families = detect_gene_families(gene_activations)
print(f" {len(gene_families)} features with dominant gene family")

# Update report label columns with gene families
from codonfm_sae.eval.gene_enrichment import build_feature_label_columns

report.feature_label_columns = build_feature_label_columns(
report.per_feature, report.n_features_total, gene_families=gene_families
)

# 7. Save results (before GO Slim so we don't lose GSEA work on failure)
print("\n" + "=" * 60)
print("SAVING RESULTS")
Expand All @@ -524,7 +537,9 @@ def main():
# Rebuild label columns with GO Slim info
from codonfm_sae.eval.gene_enrichment import build_feature_label_columns

report.feature_label_columns = build_feature_label_columns(report.per_feature, report.n_features_total)
report.feature_label_columns = build_feature_label_columns(
report.per_feature, report.n_features_total, gene_families=gene_families
)

n_slim = sum(1 for fl in report.per_feature if fl.go_slim_name is not None)
slim_names = {fl.go_slim_name for fl in report.per_feature if fl.go_slim_name is not None}
Expand Down
Loading
Loading