diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/App.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/App.jsx
index 2488ba2870..b3c448d879 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/App.jsx
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/App.jsx
@@ -191,6 +191,7 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
const [mosaicReady, setMosaicReady] = useState(false)
const [categoryColumns, setCategoryColumns] = useState([])
const [selectedCategory, setSelectedCategory] = useState('mean_variant_1bcdwt')
+ const [hiddenCategories, setHiddenCategories] = useState(new Set())
const [clickedFeatureId, setClickedFeatureId] = useState(null)
const [clusterLabels, setClusterLabels] = useState(null)
const [vocabLogits, setVocabLogits] = useState(null)
@@ -213,6 +214,7 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
const endOfListRef = useRef(null)
const searchSource = useRef({ source: 'search' })
const editedSource = useRef({ source: 'edited' })
+ const legendSource = useRef({ source: 'legend' })
const loadingMoreRef = useRef(false)
// Lazy-load examples for a single feature from DuckDB (feature_examples VIEW)
@@ -480,12 +482,33 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
if (['x', 'y', 'feature_id', 'top_example_idx'].includes(col.name)) continue
if (col.type === 'VARCHAR') {
+ const isGsea = col.name.startsWith('gsea_')
+ const maxUnique = isGsea ? Infinity : 50
const cardinalityResult = await vg.coordinator().query(`
- SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL
+ SELECT COUNT(DISTINCT "${col.name}") as n_unique FROM features WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled'
`)
const nUnique = cardinalityResult.toArray()[0]?.n_unique ?? 0
- if (nUnique > 0 && nUnique <= 50) {
- detectedCategories.push({ name: col.name, type: 'string', nUnique })
+ if (nUnique > 0 && nUnique <= maxUnique) {
+ // For high-cardinality GSEA columns, collapse to top 20 + "other"
+ if (isGsea && nUnique > 20) {
+ await vg.coordinator().exec(`
+ CREATE OR REPLACE TABLE features AS
+ SELECT * REPLACE (
+ CASE
+ WHEN "${col.name}" IS NULL OR "${col.name}" = 'unlabeled' THEN 'unlabeled'
+ WHEN "${col.name}" IN (
+ SELECT "${col.name}" FROM features
+ WHERE "${col.name}" IS NOT NULL AND "${col.name}" != 'unlabeled'
+ GROUP BY "${col.name}" ORDER BY COUNT(*) DESC LIMIT 20
+ ) THEN "${col.name}"
+ ELSE 'other'
+ END AS "${col.name}"
+ ) FROM features
+ `)
+ detectedCategories.push({ name: col.name, type: 'string', nUnique: 22 })
+ } else {
+ detectedCategories.push({ name: col.name, type: 'string', nUnique })
+ }
}
} else if (col.type === 'BIGINT' || col.type === 'INTEGER') {
if (col.name.includes('cluster') || col.name.includes('category') || col.name.includes('group')) {
@@ -564,7 +587,12 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
SELECT * FROM read_parquet('${examplesUrl}')
`)
- // Load features from the features table (which has labels)
+ // Load features from the features table (which has labels + category columns)
+ const categorySelectCols = detectedCategories
+ .filter(c => c.type === 'string' || c.type === 'integer')
+ .map(c => `"${c.name}"`)
+ .join(', ')
+ const extraSelect = categorySelectCols ? `, ${categorySelectCols}` : ''
const featuresResult = await vg.coordinator().query(`
SELECT
feature_id,
@@ -573,18 +601,27 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
max_activation,
x,
y
+ ${extraSelect}
FROM features
ORDER BY feature_id
`)
- const loadedFeatures = featuresResult.toArray().map(row => ({
- feature_id: row.feature_id,
- label: row.label,
- description: row.label,
- activation_freq: row.activation_freq,
- max_activation: row.max_activation,
- x: row.x,
- y: row.y,
- }))
+ const loadedFeatures = featuresResult.toArray().map(row => {
+ const f = {
+ feature_id: row.feature_id,
+ label: row.label,
+ description: row.label,
+ activation_freq: row.activation_freq,
+ max_activation: row.max_activation,
+ x: row.x,
+ y: row.y,
+ }
+ for (const col of detectedCategories) {
+ if (col.type === 'string' || col.type === 'integer') {
+ f[col.name] = row[col.name]
+ }
+ }
+ return f
+ })
setFeatures(loadedFeatures)
// Generate cluster labels from DuckDB (non-fatal if cluster_id doesn't exist)
@@ -777,6 +814,7 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
setSelectedFeatureIds(null)
setSearchTerm('')
setClickedFeatureId(null)
+ setHiddenCategories(new Set())
// Reset viewport to the auto-fit view captured on first load
if (initialViewportRef.current) {
setViewportState({ ...initialViewportRef.current })
@@ -927,6 +965,41 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
}
}, [showEditedOnly, mosaicReady, features])
+ // Update Mosaic crossfilter when legend selection changes
+ useEffect(() => {
+ if (!brushRef.current || !mosaicReady) return
+
+ const selection = brushRef.current
+
+ if (hiddenCategories.size > 0 && selectedCategory && selectedCategory !== 'none') {
+ const colInfo = categoryColumns.find(c => c.name === selectedCategory)
+ if (colInfo && (colInfo.type === 'string' || colInfo.type === 'integer')) {
+ const values = Array.from(hiddenCategories).map(v => `'${String(v).replace(/'/g, "''")}'`).join(',')
+ const predicateSql = `"${selectedCategory}" IN (${values})`
+
+ try {
+ selection.update({
+ source: legendSource.current,
+ predicate: predicateSql,
+ value: Array.from(hiddenCategories).join(',')
+ })
+ } catch (err) {
+ console.warn('Legend filter update failed:', err)
+ }
+ }
+ } else {
+ try {
+ selection.update({
+ source: legendSource.current,
+ predicate: null,
+ value: null
+ })
+ } catch (err) {
+ // Ignore
+ }
+ }
+ }, [hiddenCategories, selectedCategory, mosaicReady, categoryColumns])
+
// Handle search - updates both Mosaic crossfilter (for UMAP/histograms) and local state (for cards)
const handleSearchChange = useCallback((e) => {
const term = e.target.value
@@ -1124,6 +1197,7 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
onChange={e => {
const val = e.target.value
setSelectedCategory(val)
+ setHiddenCategories(new Set())
setHistMetric3(val)
setClickedFeatureId(null)
setCardResetKey(k => k + 1)
@@ -1166,41 +1240,136 @@ export default function App({ title = "SAE Feature Explorer", subtitle = "Explor
features={features}
selectedCategory={selectedCategory}
darkMode={darkMode}
+ hiddenCategories={hiddenCategories}
/>
)}
{selectedCategory && selectedCategory !== 'none' && (() => {
const colInfo = categoryColumns.find(c => c.name === selectedCategory)
- if (!colInfo || colInfo.type !== 'sequential') return null
- const colors = [
- "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500",
- "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020"
- ]
- const vals = features
- .map(f => f[selectedCategory])
- .filter(v => v != null && !isNaN(v))
- const minVal = vals.length > 0 ? Math.min(...vals) : 0
- const maxVal = vals.length > 0 ? Math.max(...vals) : 1
- const fmt = (v) => Math.abs(v) >= 100 ? v.toFixed(0) : Math.abs(v) >= 1 ? v.toFixed(1) : v.toFixed(3)
- return (
-
-
{fmt(maxVal)}
+ if (!colInfo) return null
+
+ if (colInfo.type === 'sequential') {
+ const colors = [
+ "#c359ef", "#9525C6", "#0046a4", "#0074DF", "#3f8500",
+ "#76B900", "#ef9100", "#F9C500", "#ff8181", "#EF2020"
+ ]
+ const vals = features
+ .map(f => f[selectedCategory])
+ .filter(v => v != null && !isNaN(v))
+ const minVal = vals.length > 0 ? Math.min(...vals) : 0
+ const maxVal = vals.length > 0 ? Math.max(...vals) : 1
+ const fmt = (v) => Math.abs(v) >= 100 ? v.toFixed(0) : Math.abs(v) >= 1 ? v.toFixed(1) : v.toFixed(3)
+ return (
+
+
{fmt(maxVal)}
+
+
{fmt(minVal)}
+
+ {selectedCategory.replace(/_/g, ' ')}
+
+
+ )
+ }
+
+ if (colInfo.type === 'string' || colInfo.type === 'integer') {
+ const catColors = [
+ "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
+ "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
+ "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5",
+ "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5"
+ ]
+ // Count occurrences of each category value, sorted alphabetically
+ // (matching DENSE_RANK ORDER BY which is alphabetical)
+ const counts = {}
+ for (const f of features) {
+ const val = f[selectedCategory]
+ if (val != null && val !== '') {
+ counts[val] = (counts[val] || 0) + 1
+ }
+ }
+ // Sort alphabetically to match dense_rank ordering
+ const sortedCategories = Object.keys(counts).sort()
+ return (
-
{fmt(minVal)}
-
- {selectedCategory.replace(/_/g, ' ')}
-
-
- )
+
+ {selectedCategory.replace(/_/g, ' ').replace('gsea ', '')}
+
+ {sortedCategories.map((cat, i) => {
+ const hasFilter = hiddenCategories.size > 0
+ const isHidden = hasFilter && !hiddenCategories.has(cat)
+ return (
+ {
+ if (e.metaKey || e.ctrlKey) {
+ // Cmd/Ctrl+click: toggle this category in the selection
+ setHiddenCategories(prev => {
+ const next = new Set(prev)
+ if (next.has(cat)) {
+ next.delete(cat)
+ // If nothing left selected, clear filter
+ return next.size === 0 ? new Set() : next
+ } else {
+ next.add(cat)
+ return next
+ }
+ })
+ } else {
+ // Regular click: solo this category (or clear if already solo'd)
+ setHiddenCategories(prev => {
+ if (prev.size === 1 && prev.has(cat)) return new Set()
+ return new Set([cat])
+ })
+ }
+ }}
+ style={{
+ display: 'flex', alignItems: 'center', gap: '5px', padding: '2px 0',
+ cursor: 'pointer', opacity: isHidden ? 0.15 : 1,
+ userSelect: 'none',
+ }}
+ >
+
+
+ {cat}
+
+
+ {counts[cat]}
+
+
+ )
+ })}
+
+ )
+ }
+
+ return null
})()}
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/EmbeddingView.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/EmbeddingView.jsx
index ecb9ee0e02..bc14226257 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/EmbeddingView.jsx
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/EmbeddingView.jsx
@@ -64,7 +64,7 @@ class FeatureTooltip {
}
}
-export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode }) {
+export default function EmbeddingView({ brush, categoryColumn, categoryColumns, onFeatureClick, highlightedFeatureId, viewportState, onViewportChange, labels, features, selectedCategory, darkMode, hiddenCategories }) {
const containerRef = useRef(null)
const viewRef = useRef(null)
const onFeatureClickRef = useRef(onFeatureClick)
@@ -267,7 +267,8 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
if (!viewRef.current) return
let categoryColName = null
- let colors = Array(50).fill(DEFAULT_COLOR)
+ const HIDDEN_COLOR = darkMode ? "#0a0a0a" : "#fafafa"
+ let colors = Array(50).fill(HIDDEN_COLOR)
if (categoryColumn && categoryColumn !== "none") {
const colInfo = categoryColumns?.find(c => c.name === categoryColumn)
@@ -278,6 +279,17 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
} else if (colInfo.type === 'string') {
categoryColName = `${categoryColumn}_cat`
colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10))
+ // Map colors to match DENSE_RANK order, dim non-selected when filtering
+ if (hiddenCategories && hiddenCategories.size > 0 && features) {
+ const allCatNames = [...new Set(
+ features.map(f => f[categoryColumn]).filter(v => v != null)
+ )].sort()
+ colors = colors.map((c, i) => {
+ const name = allCatNames[i]
+ if (!name) return c
+ return !hiddenCategories.has(name) ? HIDDEN_COLOR : c
+ })
+ }
} else {
categoryColName = categoryColumn
colors = CATEGORY_COLORS.slice(0, Math.max(colInfo.nUnique, 10))
@@ -291,7 +303,7 @@ export default function EmbeddingView({ brush, categoryColumn, categoryColumns,
selection: null,
tooltip: null,
})
- }, [categoryColumn, categoryColumns])
+ }, [categoryColumn, categoryColumns, hiddenCategories, features, darkMode])
// Handle resize
useEffect(() => {
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureCard.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureCard.jsx
index a432993bbf..d101b14225 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureCard.jsx
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureCard.jsx
@@ -349,6 +349,25 @@ const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, fo
lines.push('')
}
+ // GSEA enrichment section
+ const gseaCsvFields = [
+ { key: 'gsea_overall_best', label: 'GSEA Overall Best' },
+ { key: 'gsea_GO_Biological_Process', label: 'GSEA GO Biological Process' },
+ { key: 'gsea_GO_Molecular_Function', label: 'GSEA GO Molecular Function' },
+ { key: 'gsea_GO_Cellular_Component', label: 'GSEA GO Cellular Component' },
+ { key: 'gsea_InterPro_Domains', label: 'GSEA InterPro Domains' },
+ { key: 'gsea_Pfam_Domains', label: 'GSEA Pfam Domains' },
+ { key: 'gsea_GO_Slim', label: 'GSEA GO Slim' },
+ ]
+ const gseaLines = gseaCsvFields
+ .filter(({ key }) => feature[key] && feature[key] !== 'unlabeled' && feature[key] !== 'other')
+ .map(({ key, label }) => `${label},${feature[key]}`)
+ if (gseaLines.length > 0) {
+ lines.push('=== GSEA ENRICHMENT ===')
+ gseaLines.forEach(l => lines.push(l))
+ lines.push('')
+ }
+
// Examples section
if (examples && examples.length > 0) {
lines.push('=== ACTIVATION EXAMPLES ===')
@@ -657,6 +676,22 @@ const FeatureCard = forwardRef(function FeatureCard({ feature, isHighlighted, fo
if (ann.cpg) tags.push({ label: `CpG enriched`, color: '#fce4ec' })
if (ann.position) tags.push({ label: `N-terminal`, color: '#e8f5e9' })
+ // GSEA enrichment tags
+ const gseaFields = [
+ { key: 'gsea_GO_Biological_Process', prefix: 'GO:BP', color: '#e8eaf6' },
+ { key: 'gsea_GO_Molecular_Function', prefix: 'GO:MF', color: '#ede7f6' },
+ { key: 'gsea_GO_Cellular_Component', prefix: 'GO:CC', color: '#e0f2f1' },
+ { key: 'gsea_InterPro_Domains', prefix: 'InterPro', color: '#fff8e1' },
+ { key: 'gsea_Pfam_Domains', prefix: 'Pfam', color: '#fbe9e7' },
+ { key: 'gsea_GO_Slim', prefix: 'GO Slim', color: '#f1f8e9' },
+ ]
+ for (const { key, prefix, color } of gseaFields) {
+ const val = feature[key]
+ if (val && val !== 'unlabeled' && val !== 'other') {
+ tags.push({ label: `${prefix}: ${val}`, color })
+ }
+ }
+
if (tags.length === 0) return null
return (
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureDetailPage.jsx b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureDetailPage.jsx
index 695b9181ab..5775f42b41 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureDetailPage.jsx
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/codon_dashboard/src/FeatureDetailPage.jsx
@@ -369,6 +369,54 @@ export default function FeatureDetailPage({ feature, examples, vocabLogits, feat
+ {/* Gene-Level GSEA Enrichment */}
+ {(() => {
+ const gseaFields = [
+ { key: 'gsea_GO_Biological_Process', label: 'GO Biological Process' },
+ { key: 'gsea_GO_Molecular_Function', label: 'GO Molecular Function' },
+ { key: 'gsea_GO_Cellular_Component', label: 'GO Cellular Component' },
+ { key: 'gsea_InterPro_Domains', label: 'InterPro Domains' },
+ { key: 'gsea_Pfam_Domains', label: 'Pfam Domains' },
+ { key: 'gsea_GO_Slim', label: 'GO Slim' },
+ ]
+ const gseaEntries = gseaFields
+ .map(({ key, label }) => ({ label, value: feature[key] }))
+ .filter(e => e.value && e.value !== 'unlabeled' && e.value !== 'other')
+ const overallBest = feature.gsea_overall_best
+ if (gseaEntries.length === 0 && (!overallBest || overallBest === 'unlabeled')) return null
+ return (
+
+
Gene-Level Enrichment (GSEA)
+
+ Genes ranked by activation strength, tested against GO, InterPro, and Pfam databases.
+
+ {overallBest && overallBest !== 'unlabeled' && (
+
+ Best: {overallBest}
+
+ )}
+
+ {gseaEntries.map(({ label, value }) => (
+
+
+ {label}
+
+
{value}
+
+ ))}
+
+
+ )
+ })()}
+
{/* Codon Annotations */}
Codon-Level Annotations
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run.py
index da1db78626..dca3646dd8 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run.py
@@ -133,6 +133,14 @@ def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None: # no
cmd.append("--normalize-input")
if t.get("max_grad_norm"):
cmd.extend(["--max-grad-norm", str(t.max_grad_norm)])
+ if t.get("lr_schedule") not in (None, "constant"):
+ cmd.extend(["--lr-schedule", str(t.lr_schedule)])
+ if t.get("lr_min") not in (None, 0.0):
+ cmd.extend(["--lr-min", str(t.lr_min)])
+ if t.get("lr_decay_steps"):
+ cmd.extend(["--lr-decay-steps", str(t.lr_decay_steps)])
+ if (t.get("warmup_steps") or 0) > 0:
+ cmd.extend(["--warmup-steps", str(t.warmup_steps)])
if t.wandb_enabled:
cmd.append("--wandb")
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run_configs/config.yaml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run_configs/config.yaml
index ec350a9764..36be6b27b5 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run_configs/config.yaml
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/run_configs/config.yaml
@@ -42,6 +42,10 @@ train:
wandb_enabled: false
wandb_project: sae_codonfm_recipe
max_grad_norm: null
+ lr_schedule: constant
+ lr_min: 0.0
+ lr_decay_steps: null
+ warmup_steps: 0
# Eval
eval:
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/analyze.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/analyze.py
index 06b05fe982..077d5ac3fb 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/analyze.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/analyze.py
@@ -227,6 +227,12 @@ def parse_args(): # noqa: D103
p.add_argument(
"--auto-interp-workers", type=int, default=1, help="Number of parallel workers for LLM calls (default: 1)"
)
+ p.add_argument(
+ "--gsea-report",
+ type=str,
+ default=None,
+ help="Path to gene_enrichment_report.json — adds GSEA context to auto-interp prompts",
+ )
p.add_argument("--seed", type=int, default=42)
p.add_argument("--device", type=str, default=None)
return p.parse_args()
@@ -569,6 +575,7 @@ def run_auto_interp(
llm_provider="anthropic",
llm_model=None,
num_workers=1,
+ gsea_context=None,
):
"""Run LLM auto-interpretation using precomputed top-K indices.
@@ -712,21 +719,41 @@ def interpret_feature(f):
examples_str = "\n".join(f" Seq {i + 1}: {ex}" for i, ex in enumerate(feature_examples.get(f, [])))
- prompt = f"""This is a feature from a sparse autoencoder trained on a DNA codon language model (CodonFM).
-Each token is a codon (3 nucleotides) that encodes an amino acid.
+ # Build GSEA enrichment context if available
+ gsea_str = ""
+ if gsea_context and f in gsea_context:
+ gsea_info = gsea_context[f]
+ gsea_lines = []
+ for db, entry in gsea_info.items():
+ if entry:
+ gsea_lines.append(f" {db}: {entry['term_name']} (FDR={entry['fdr']:.4f})")
+ if gsea_lines:
+ gsea_str = "\n\nGene-level GSEA enrichment (genes ranked by activation, tested against annotation databases):\n"
+ gsea_str += "\n".join(gsea_lines)
+
+ prompt = f"""Analyze this sparse autoencoder feature from a DNA codon language model (CodonFM) to determine what predicts its activation pattern. Each token is a codon (3 nucleotides encoding one amino acid).
Top promoted codons (decoder logits): {pos_str}
Top suppressed codons: {neg_str}
-Top activating sequences (***highlighted*** = high activation):
-Each sequence may include metadata in brackets: gene name, data source (ClinVar=germline variants, COSMIC=somatic cancer mutations), pathogenicity label, PhyloP conservation score, variant info (ref>alt codon at position), and model effect score (more negative = higher predicted impact).
-{examples_str}
+Top activating sequences (***highlighted*** = high activation codons):
+Metadata in brackets may include: gene name, data source (ClinVar/COSMIC), pathogenicity, PhyloP conservation, variant info (ref>alt codon at position), model effect score.
+{examples_str}{gsea_str}
-In 1 short sentence starting with "Fires on", describe what biological pattern this feature detects.
-Consider: amino acid identity, specific codon choice, codon usage bias, positional context, CpG sites, wobble position patterns, and any variant/clinical metadata patterns you observe.
+Analyze what predicts high vs low activation for this feature. This description should be concise but sufficient to predict activation levels on unseen codon sequences. The feature could be specific to a gene family, a codon usage pattern, a sequence motif, a functional role, a structural domain, etc.
+
+Focus on:
+- Which codons and amino acids are associated with high vs low activation, and whether specific synonymous codon choices matter
+- Where in the gene sequence activation occurs (N-terminal, C-terminal, or throughout)
+- What gene-level functional annotations (from GSEA enrichment if provided) characterize the top-activating genes
+- Whether codon usage bias, CpG content, wobble position patterns, or GC content are relevant
+- Any variant/clinical metadata patterns (pathogenicity, conservation, mutation impact)
+
+Your description will be used to predict activation on held-out sequences, so only highlight factors relevant for prediction.
Format your response as:
-Label:
+Description: <2-3 sentences starting with "The activation patterns are characterized by:">
+Label:
Confidence: <0.00 to 1.00>"""
try:
@@ -734,11 +761,14 @@ def interpret_feature(f):
text = response.text.strip()
label = None
+ description = None
confidence = 0.0
for line in text.split("\n"):
if line.startswith("Label:"):
label = line.replace("Label:", "").strip()
+ elif line.startswith("Description:"):
+ description = line.replace("Description:", "").strip()
elif line.startswith("Confidence:"):
try:
confidence = float(line.replace("Confidence:", "").strip())
@@ -749,21 +779,24 @@ def interpret_feature(f):
if not label:
label = f"Feature {f}"
- return f, label, confidence
+ return f, label, confidence, description
except Exception as e:
print(f" Warning: auto-interp failed for feature {f}: {e}")
- return f, f"Feature {f}", 0.0
+ return f, f"Feature {f}", 0.0, None
interpretations = {}
confidences = {}
+ descriptions = {}
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(interpret_feature, f): f for f in feature_indices}
for future in tqdm(as_completed(futures), total=len(feature_indices), desc=" Auto-interp"):
- f, label, confidence = future.result()
+ f, label, confidence, description = future.result()
interpretations[f] = label
confidences[f] = confidence
+ if description:
+ descriptions[f] = description
- return interpretations, confidences
+ return interpretations, confidences, descriptions
# ── Build summary labels ─────────────────────────────────────────────
@@ -927,6 +960,29 @@ def main(): # noqa: D103
auto_interp_labels[k_int] = {"label": v, "confidence": 0.0}
print(f" Loaded {len(auto_interp_labels)} existing interpretations")
+ # Load GSEA context if provided
+ gsea_context = None
+ if args.gsea_report:
+ gsea_report_path = Path(args.gsea_report)
+ if gsea_report_path.exists():
+ print(f" Loading GSEA report from {gsea_report_path}...")
+ with open(gsea_report_path) as f:
+ gsea_data = json.load(f)
+ gsea_context = {}
+ for fl in gsea_data.get("per_feature", []):
+ feat_idx = fl["feature_idx"]
+ per_db = {}
+ for db, entry in fl.get("best_per_database", {}).items():
+ if entry is not None:
+ per_db[db] = entry
+ if fl.get("overall_best"):
+ per_db["overall_best"] = fl["overall_best"]
+ if per_db:
+ gsea_context[feat_idx] = per_db
+ print(f" GSEA context loaded for {len(gsea_context)} features")
+ else:
+ print(f" WARNING: GSEA report not found at {gsea_report_path}")
+
if args.auto_interp:
print("\n[3/3] Auto-interpretation (LLM)...")
alive_features = [f for f in range(n_features) if f in codon_annotations]
@@ -942,7 +998,7 @@ def main(): # noqa: D103
if todo_features:
print(f" Running auto-interp on {len(todo_features)} features ({len(auto_interp_labels)} already done)")
- new_labels, new_confidences = run_auto_interp(
+ new_labels, new_confidences, new_descriptions = run_auto_interp(
sae,
vocab_logits,
inference,
@@ -957,11 +1013,13 @@ def main(): # noqa: D103
llm_provider=args.llm_provider,
llm_model=args.llm_model,
num_workers=args.auto_interp_workers,
+ gsea_context=gsea_context,
)
for f in new_labels:
auto_interp_labels[f] = {
"label": new_labels[f],
"confidence": new_confidences[f],
+ "description": new_descriptions.get(f),
}
with open(auto_interp_ckpt, "w") as f:
json.dump(auto_interp_labels, f, indent=2)
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval_gene_enrichment.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval_gene_enrichment.py
index 76d943c64e..1b567d838f 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval_gene_enrichment.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/eval_gene_enrichment.py
@@ -52,6 +52,7 @@
from codonfm_sae.eval.gene_enrichment import ( # noqa: E402
ANNOTATION_DATABASES,
GeneEnrichmentReport,
+ detect_gene_families,
download_obo_files,
rollup_go_slim,
run_gene_enrichment,
@@ -499,6 +500,18 @@ def main():
gsea_time = time.time() - t0
print(f"\n GSEA completed in {gsea_time:.1f}s")
+ # 6b. Detect gene families
+ print(" Detecting gene families...")
+ gene_families = detect_gene_families(gene_activations)
+ print(f" {len(gene_families)} features with dominant gene family")
+
+ # Update report label columns with gene families
+ from codonfm_sae.eval.gene_enrichment import build_feature_label_columns
+
+ report.feature_label_columns = build_feature_label_columns(
+ report.per_feature, report.n_features_total, gene_families=gene_families
+ )
+
# 7. Save results (before GO Slim so we don't lose GSEA work on failure)
print("\n" + "=" * 60)
print("SAVING RESULTS")
@@ -524,7 +537,9 @@ def main():
# Rebuild label columns with GO Slim info
from codonfm_sae.eval.gene_enrichment import build_feature_label_columns
- report.feature_label_columns = build_feature_label_columns(report.per_feature, report.n_features_total)
+ report.feature_label_columns = build_feature_label_columns(
+ report.per_feature, report.n_features_total, gene_families=gene_families
+ )
n_slim = sum(1 for fl in report.per_feature if fl.go_slim_name is not None)
slim_names = {fl.go_slim_name for fl in report.per_feature if fl.go_slim_name is not None}
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/train.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/train.py
index 1762fe8ca9..c610581e61 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/train.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/scripts/train.py
@@ -80,6 +80,15 @@ def parse_args(): # noqa: D103
train_group.add_argument("--max-grad-norm", type=float, default=None)
train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False)
train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048)
+ train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps")
+ train_group.add_argument(
+ "--lr-schedule", type=str, default="constant", choices=["constant", "cosine", "linear"],
+ help="LR schedule after warmup",
+ )
+ train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules")
+ train_group.add_argument(
+ "--lr-decay-steps", type=int, default=None, help="Total steps for LR decay (None = full training)",
+ )
# W&B
wb_group = p.add_argument_group("Weights & Biases")
@@ -147,6 +156,11 @@ def build_training_config(args, device: str) -> TrainingConfig: # noqa: D103
checkpoint_steps=args.checkpoint_steps,
lr_scale_with_latents=args.lr_scale_with_latents,
lr_reference_hidden_dim=args.lr_reference_hidden_dim,
+ warmup_steps=args.warmup_steps,
+ max_grad_norm=args.max_grad_norm,
+ lr_schedule=args.lr_schedule,
+ lr_min=args.lr_min,
+ lr_decay_steps=args.lr_decay_steps,
)
@@ -277,7 +291,6 @@ def main(): # noqa: D103
print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync")
trainer.fit(
dataloader,
- max_grad_norm=args.max_grad_norm,
resume_from=args.resume_from,
data_sharded=True,
)
@@ -292,7 +305,6 @@ def main(): # noqa: D103
trainer.fit(
activations_flat,
- max_grad_norm=args.max_grad_norm,
resume_from=args.resume_from,
)
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/src/codonfm_sae/eval/gene_enrichment.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/src/codonfm_sae/eval/gene_enrichment.py
index 2299756a50..dda6684ee1 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/src/codonfm_sae/eval/gene_enrichment.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/codonfm/src/codonfm_sae/eval/gene_enrichment.py
@@ -187,6 +187,8 @@ def run_gsea_for_feature(
res = gseapy.prerank(
rnk=series,
gene_sets=db,
+ min_size=5,
+ max_size=1000,
no_plot=True,
outdir=None,
verbose=False,
@@ -201,8 +203,6 @@ def run_gsea_for_feature(
fdr_col = "FDR q-val" if "FDR q-val" in df.columns else "fdr"
es_col = "NES" if "NES" in df.columns else "nes"
pval_col = "NOM p-val" if "NOM p-val" in df.columns else "pval"
- geneset_size_col = "Gene %" if "Gene %" in df.columns else "geneset_size"
-
df[fdr_col] = pd.to_numeric(df[fdr_col], errors="coerce")
df = df.dropna(subset=[fdr_col])
@@ -221,7 +221,15 @@ def run_gsea_for_feature(
fdr_val = float(best_row[fdr_col])
es_val = float(best_row.get(es_col, 0.0))
pval = float(best_row.get(pval_col, 1.0))
- n_genes = int(best_row.get(geneset_size_col, 0)) if geneset_size_col in df.columns else 0
+
+ # Parse n_genes from "Tag %" column (format: "6/200") or fall back to 0
+ n_genes = 0
+ tag_pct = str(best_row.get("Tag %", ""))
+ if "/" in tag_pct:
+ try:
+ n_genes = int(tag_pct.split("/")[1])
+ except (ValueError, IndexError):
+ pass
result = EnrichmentResult(
feature_idx=feature_idx,
@@ -248,6 +256,13 @@ def run_gsea_for_feature(
continue # Already added the best
t_id = _parse_go_id(t_raw) if is_go else t_raw
t_name = _parse_term_name(t_raw) if is_go else t_raw
+ row_n_genes = 0
+ row_tag = str(row.get("Tag %", ""))
+ if "/" in row_tag:
+ try:
+ row_n_genes = int(row_tag.split("/")[1])
+ except (ValueError, IndexError):
+ pass
all_significant.append(
EnrichmentResult(
feature_idx=feature_idx,
@@ -257,7 +272,7 @@ def run_gsea_for_feature(
enrichment_score=float(row.get(es_col, 0.0)),
pvalue=float(row.get(pval_col, 1.0)),
fdr=float(row[fdr_col]),
- n_genes_in_term=int(row.get(geneset_size_col, 0)) if geneset_size_col in df.columns else 0,
+ n_genes_in_term=row_n_genes,
)
)
@@ -475,12 +490,58 @@ def _find_slim_ancestor(go_id: str) -> Optional[Tuple[str, str]]:
return feature_labels
+# ── Gene family detection ────────────────────────────────────────────────
+
+
+def _gene_prefix(gene_name: str) -> str:
+ """Extract the alphabetic prefix of a gene name (letters before first digit)."""
+ prefix = ""
+ for c in gene_name:
+ if c.isdigit():
+ break
+ prefix += c
+ return prefix
+
+
+def detect_gene_families(
+ gene_activations: Dict[int, Dict[str, float]],
+ top_k: int = 10,
+ min_fraction: float = 0.5,
+) -> Dict[int, str]:
+ """Detect dominant gene family for each feature based on top-K gene name prefixes.
+
+ Args:
+ gene_activations: feature_idx -> gene_name -> activation score.
+ top_k: Number of top genes to examine per feature.
+ min_fraction: Minimum fraction of top-K genes sharing a prefix to call it a family.
+
+ Returns:
+ feature_idx -> gene family label (e.g., "OR family (8/10)") or absent if no family.
+ """
+ from collections import Counter
+
+ result = {}
+ for feat_idx, gene_scores in gene_activations.items():
+ top_genes = sorted(gene_scores.keys(), key=lambda g: gene_scores[g], reverse=True)[:top_k]
+ if len(top_genes) < 3:
+ continue
+ prefixes = [_gene_prefix(g) for g in top_genes]
+ counts = Counter(p for p in prefixes if len(p) >= 2)
+ if not counts:
+ continue
+ top_prefix, top_count = counts.most_common(1)[0]
+ if top_count / len(top_genes) >= min_fraction:
+ result[feat_idx] = f"{top_prefix} family ({top_count}/{len(top_genes)})"
+ return result
+
+
# ── Label columns for UMAP ──────────────────────────────────────────────
def build_feature_label_columns(
per_feature: List[FeatureLabels],
n_features: int,
+ gene_families: Optional[Dict[int, str]] = None,
) -> Dict[str, Dict[int, str]]:
"""Build dict[column_name, dict[feature_idx, label]] for UMAP dropdown.
@@ -503,6 +564,7 @@ def build_feature_label_columns(
"InterPro_Domains": {},
"Pfam_Domains": {},
"GO_Slim": {},
+ "gene_family": {},
}
for fl in per_feature:
@@ -525,6 +587,11 @@ def build_feature_label_columns(
else:
columns["GO_Slim"][idx] = "unlabeled"
+ if gene_families and idx in gene_families:
+ columns["gene_family"][idx] = gene_families[idx]
+ else:
+ columns["gene_family"][idx] = "unlabeled"
+
# Fill missing feature indices with "unlabeled"
for col in columns:
for i in range(n_features):
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run.py
index 953acf68df..b10df6378b 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run.py
@@ -143,6 +143,14 @@ def run_train(cfg: DictConfig, cache_dir: Path, output_dir: Path) -> None:
cmd.append("--normalize-input")
if t.get("max_grad_norm"):
cmd.extend(["--max-grad-norm", str(t.max_grad_norm)])
+ if t.get("lr_schedule", "constant") != "constant":
+ cmd.extend(["--lr-schedule", str(t.lr_schedule)])
+ if t.get("lr_min", 0.0) != 0.0:
+ cmd.extend(["--lr-min", str(t.lr_min)])
+ if t.get("lr_decay_steps"):
+ cmd.extend(["--lr-decay-steps", str(t.lr_decay_steps)])
+ if t.get("warmup_steps", 0) > 0:
+ cmd.extend(["--warmup-steps", str(t.warmup_steps)])
# W&B
if t.wandb_enabled:
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run_configs/config.yaml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run_configs/config.yaml
index 63244a1f70..979ef8cc94 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run_configs/config.yaml
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/run_configs/config.yaml
@@ -48,6 +48,11 @@ train:
batch_size: 4096
log_interval: 50
checkpoint_steps: 999999
+ max_grad_norm: null
+ lr_schedule: constant
+ lr_min: 0.0
+ lr_decay_steps: null
+ warmup_steps: 0
wandb_enabled: false
wandb_project: sae_esm2_recipe
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/scripts/train.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/scripts/train.py
index 6d1aa6e37e..5f7df9472f 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/scripts/train.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/esm2/scripts/train.py
@@ -87,6 +87,15 @@ def parse_args():
train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False)
train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048)
train_group.add_argument("--grad-accumulation-steps", type=int, default=1, help="Gradient accumulation steps")
+ train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps")
+ train_group.add_argument(
+ "--lr-schedule", type=str, default="constant", choices=["constant", "cosine", "linear"],
+ help="LR schedule after warmup",
+ )
+ train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules")
+ train_group.add_argument(
+ "--lr-decay-steps", type=int, default=None, help="Total steps for LR decay (None = full training)",
+ )
# W&B
wb_group = p.add_argument_group("Weights & Biases")
@@ -157,6 +166,11 @@ def build_training_config(args, device: str) -> TrainingConfig:
lr_scale_with_latents=args.lr_scale_with_latents,
lr_reference_hidden_dim=args.lr_reference_hidden_dim,
grad_accumulation_steps=args.grad_accumulation_steps,
+ warmup_steps=args.warmup_steps,
+ max_grad_norm=args.max_grad_norm,
+ lr_schedule=args.lr_schedule,
+ lr_min=args.lr_min,
+ lr_decay_steps=args.lr_decay_steps,
)
@@ -288,7 +302,6 @@ def main():
print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync")
trainer.fit(
dataloader,
- max_grad_norm=args.max_grad_norm,
resume_from=args.resume_from,
data_sharded=True,
)
@@ -304,7 +317,6 @@ def main():
trainer.fit(
activations_flat,
- max_grad_norm=args.max_grad_norm,
resume_from=args.resume_from,
)
diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py
index c9d8a9070a..e04213ba01 100644
--- a/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py
+++ b/bionemo-recipes/interpretability/sparse_autoencoders/sae/src/sae/training.py
@@ -20,6 +20,7 @@
"""
import contextlib
+import math
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -64,6 +65,10 @@ class TrainingConfig:
lr_reference_hidden_dim: Reference hidden_dim for LR scaling (default 2048)
warmup_steps: Number of steps for linear LR warmup (0 = no warmup)
grad_accumulation_steps: Number of microsteps to accumulate gradients before an optimizer step (1 = no accumulation)
+ max_grad_norm: Max gradient norm for clipping (None = no clipping)
+ lr_schedule: LR schedule after warmup ('constant', 'cosine', 'linear')
+ lr_min: Minimum LR for decay schedules
+ lr_decay_steps: Total steps for LR decay (None = use full training duration)
"""
lr: float = 3e-4
@@ -80,6 +85,10 @@ class TrainingConfig:
lr_reference_hidden_dim: int = 2048
warmup_steps: int = 0
grad_accumulation_steps: int = 1
+ max_grad_norm: Optional[float] = None
+ lr_schedule: str = "constant"
+ lr_min: float = 0.0
+ lr_decay_steps: Optional[int] = None
@dataclass
@@ -178,12 +187,20 @@ def __init__(
self.is_distributed: bool = False
self._data_sharded: bool = False
+ # Validate lr_schedule
+ valid_schedules = ("constant", "cosine", "linear")
+ if self.config.lr_schedule not in valid_schedules:
+ raise ValueError(
+ f"Unknown lr_schedule: {self.config.lr_schedule!r}. Expected one of {valid_schedules}."
+ )
+
# Will be set during training
self.dataloader: Optional[DataLoader] = None
self.wandb_run = None
self.global_step: int = 0
self.current_epoch: int = 0
self._target_lr: float = config.lr if config else 3e-4
+ self._total_decay_steps: int = 0 # computed in fit() once we know total steps
def _setup_dataloader(self, data: Union[torch.Tensor, DataLoader]) -> DataLoader:
"""Setup dataloader from tensor or existing dataloader."""
@@ -234,8 +251,6 @@ def _compute_effective_lr(self) -> float:
# Scale: lr ∝ 1/sqrt(hidden_dim)
# effective_lr = base_lr * sqrt(reference_dim / hidden_dim)
- import math
-
scale_factor = math.sqrt(reference_dim / hidden_dim)
effective_lr = base_lr * scale_factor
@@ -256,16 +271,51 @@ def _setup_optimizer(self) -> torch.optim.Optimizer:
self._target_lr = self.config.lr
return self.optimizer
- def _get_warmup_lr(self, step: int) -> float:
- """Compute learning rate with linear warmup."""
- if self.config.warmup_steps == 0 or step >= self.config.warmup_steps:
- return self._target_lr
- # Linear warmup: lr = target_lr * (step / warmup_steps)
- return self._target_lr * (step / self.config.warmup_steps)
+ def _get_lr(self, step: int) -> float:
+ """Compute learning rate with warmup and optional decay schedule.
+
+ The schedule has two phases:
+ 1. Warmup (steps 0..warmup_steps-1): linear ramp from 0 to target_lr
+ 2. Decay (steps warmup_steps..warmup_steps+decay_steps): schedule-dependent decay
+
+ Args:
+ step: Current global training step.
+
+ Returns:
+ Learning rate for this step.
+ """
+ warmup_steps = self.config.warmup_steps
+ target_lr = self._target_lr
+ lr_min = self.config.lr_min
+
+ # Phase 1: warmup
+ if warmup_steps > 0 and step < warmup_steps:
+ return target_lr * (step / warmup_steps)
+
+ # Phase 2: decay (or constant)
+ schedule = self.config.lr_schedule
+ if schedule == "constant":
+ return target_lr
+
+ decay_steps = self._total_decay_steps
+ if decay_steps <= 0:
+ return target_lr
+
+ # How far through the decay phase we are (0.0 to 1.0, clamped)
+ steps_since_warmup = step - warmup_steps
+ progress = min(steps_since_warmup / decay_steps, 1.0)
+
+ if schedule == "cosine":
+ # Cosine annealing: lr_min + 0.5 * (target - lr_min) * (1 + cos(pi * progress))
+ return lr_min + 0.5 * (target_lr - lr_min) * (1.0 + math.cos(math.pi * progress))
+ elif schedule == "linear":
+ return target_lr + (lr_min - target_lr) * progress
+ else:
+ raise ValueError(f"Unknown lr_schedule: {schedule!r}. Expected 'constant', 'cosine', or 'linear'.")
def _update_lr(self, optimizer: torch.optim.Optimizer, step: int) -> float:
- """Update learning rate based on warmup schedule."""
- lr = self._get_warmup_lr(step)
+ """Update learning rate based on schedule."""
+ lr = self._get_lr(step)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
return lr
@@ -369,6 +419,7 @@ def _setup_wandb(self) -> None:
group=self.wandb_config.group,
job_type=self.wandb_config.job_type,
config=config,
+ settings=wandb.Settings(init_timeout=300),
)
print(f"wandb run: {self.wandb_run.url}")
else:
@@ -462,7 +513,7 @@ def fit(
Args:
data: Training data (tensor or dataloader)
- max_grad_norm: Max gradient norm for clipping (None = no clipping)
+ max_grad_norm: Max gradient norm for clipping (overrides config.max_grad_norm if set)
resume_from: Path to checkpoint to resume training from (optional)
data_sharded: If True, data is already sharded per rank — skip DistributedSampler
**loss_kwargs: Additional arguments passed to loss function
@@ -470,6 +521,9 @@ def fit(
Returns:
Final training loss
"""
+ # Resolve grad clipping: fit() arg overrides config
+ if max_grad_norm is None:
+ max_grad_norm = self.config.max_grad_norm
self._data_sharded = data_sharded
# Setup distributed training first (before moving model to device)
@@ -504,6 +558,25 @@ def fit(
accum_steps = self.config.grad_accumulation_steps
global_batch_size = self.config.batch_size * self.parallel_config.dp_size * accum_steps
+ # Compute total decay steps for LR schedule
+ if self.config.lr_decay_steps is not None:
+ self._total_decay_steps = self.config.lr_decay_steps
+ elif self.config.lr_schedule != "constant":
+ # Estimate total optimizer steps from dataloader length
+ try:
+ batches_per_epoch = len(self.dataloader)
+ steps_per_epoch = batches_per_epoch // accum_steps
+ total_steps = steps_per_epoch * self.config.n_epochs
+ self._total_decay_steps = max(0, total_steps - self.config.warmup_steps)
+ except TypeError:
+ self._total_decay_steps = 0
+ self._print_rank0(
+ "WARNING: Cannot compute decay steps for streaming dataloader. "
+ "Set lr_decay_steps explicitly or use lr_schedule='constant'."
+ )
+ else:
+ self._total_decay_steps = 0
+
remaining_info = ""
if resume_from is not None:
remaining_info = f" (resuming from epoch {self.current_epoch})"
@@ -518,6 +591,13 @@ def fit(
self._print_rank0(f"Gradient accumulation: {accum_steps} microsteps")
if self.config.warmup_steps > 0:
self._print_rank0(f"LR warmup: {self.config.warmup_steps} steps")
+ if self.config.lr_schedule != "constant":
+ self._print_rank0(
+ f"LR schedule: {self.config.lr_schedule} decay over {self._total_decay_steps} steps "
+ f"(lr_min={self.config.lr_min:.2e})"
+ )
+ if max_grad_norm is not None:
+ self._print_rank0(f"Gradient clipping: max_norm={max_grad_norm}")
# If resuming, keep restored global_step and current_epoch; otherwise start fresh
if resume_from is None:
@@ -693,6 +773,9 @@ def train_sae(
device: str = "cuda",
log_interval: int = 1,
warmup_steps: int = 0,
+ max_grad_norm: Optional[float] = None,
+ lr_schedule: str = "constant",
+ lr_min: float = 0.0,
**loss_kwargs,
) -> float:
"""Convenience function to train an SAE model.
@@ -708,6 +791,9 @@ def train_sae(
device: Device to train on
log_interval: Print loss every N epochs
warmup_steps: Number of steps for linear LR warmup (0 = no warmup)
+ max_grad_norm: Max gradient norm for clipping (None = no clipping)
+ lr_schedule: LR schedule after warmup ('constant', 'cosine', 'linear')
+ lr_min: Minimum LR for decay schedules
**loss_kwargs: Additional arguments for loss function
Returns:
@@ -727,6 +813,9 @@ def train_sae(
device=device,
log_interval=log_interval,
warmup_steps=warmup_steps,
+ max_grad_norm=max_grad_norm,
+ lr_schedule=lr_schedule,
+ lr_min=lr_min,
)
trainer = Trainer(