Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# MLXLMProbe - Project Context

## Mechanistic Interpretability Resources

- **MI Glossary**: https://www.neelnanda.io/mechanistic-interpretability/glossary
- **TransformerLens**: https://github.com/TransformerLensOrg/TransformerLens

## Experimental: NumPy → MLX GPU Acceleration (2026-01-28)

**Branch:** `experimental/numpy-to-mlx-gpu`
**Commit:** `d40b2a9`
**Goal:** Replace heavy NumPy computations with MLX for Apple Silicon GPU acceleration (also compatible with mlx-cuda on Linux)

### MLX Utility Functions Added

| Function | Purpose |
|----------|---------|
| `mlx_softmax()` | GPU-accelerated softmax |
| `mlx_norm()` | GPU-accelerated L2 norm |
| `mlx_entropy()` | GPU-accelerated entropy calculation |
| `mlx_cosine_similarity_matrix()` | Batch matrix multiply for pairwise similarity |
| `mlx_to_numpy()` / `numpy_to_mlx()` | Conversion helpers |

### High-Impact Operations Replaced

| Location | Operation | Change |
|----------|-----------|--------|
| Logit Lens (~line 1920) | Softmax | Compute on GPU before NumPy conversion |
| Token Probs (~line 2380) | Softmax | Compute on GPU before NumPy conversion |
| Full Sequence Logit Lens (~line 2430) | Softmax | Compute on GPU before NumPy conversion |
| Layer Similarity (~line 4630) | O(n²) cosine similarity | Single `mx.matmul()` call |

### Not Changed (data already NumPy)

- Entropy calculations in interpretation functions
- Norm calculations in visualization functions
- Statistical operations for plotting (Plotly requires NumPy)

### Assessment

- **Probability of success:** 85%
- **Biggest win:** Layer similarity matrix - eliminates O(n²) nested loop
- **Key principle:** Minimize MLX↔NumPy transfers by computing on GPU before conversion
120 changes: 89 additions & 31 deletions mlxlmprobe.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,58 @@
__version__ = "0.1.5"


# =============================================================================
# MLX GPU-Accelerated Utilities
# =============================================================================
# These functions leverage MLX for GPU acceleration on Apple Silicon (and mlx-cuda on Linux).
# They avoid unnecessary MLX↔NumPy transfers by keeping computation on GPU.

def mlx_softmax(x: mx.array, axis: int = -1) -> mx.array:
    """Softmax computed with MLX so the result stays on the GPU.

    Args:
        x: Input logits.
        axis: Axis along which to normalize (defaults to the last axis).

    Returns:
        An MLX array the same shape as ``x``; values along ``axis``
        are non-negative and sum to 1.
    """
    result = mx.softmax(x, axis=axis)
    return result


def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array:
    """L2 (Euclidean) norm computed on the GPU via MLX.

    Args:
        x: Input array.
        axis: Axis to reduce over (defaults to the last axis).
        keepdims: If True, keep the reduced axis with size 1.

    Returns:
        ``sqrt(sum(x * x))`` reduced along ``axis``.
    """
    squared_sum = mx.sum(x * x, axis=axis, keepdims=keepdims)
    return mx.sqrt(squared_sum)


Comment on lines +93 to +97
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Handle unexpected base values explicitly in mlx_entropy.

Any base value other than "e" currently falls back to mx.log2, which can silently produce wrong results if a caller passes an unexpected value (e.g. "E", "10"). Consider validating/normalizing the argument (e.g. only allow {"e", "2"} and raise ValueError otherwise) so misconfiguration is caught instead of defaulting to log2.

Suggested change
def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array:
"""GPU-accelerated L2 norm."""
return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))
def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array:
"""GPU-accelerated L2 norm."""
return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))
def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
"""GPU-accelerated entropy calculation.
Args:
probs: Probability distribution (should sum to 1 along axis)
axis: Axis along which to compute entropy
base: "e" for natural log, "2" for log2
Raises:
ValueError: If `base` is not one of {"e", "2"}.
"""
if base == "e":
log_fn = mx.log
elif base == "2":
log_fn = mx.log2
else:
raise ValueError(f"Unsupported base for mlx_entropy: {base!r}. Expected 'e' or '2'.")
return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)

def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
    """GPU-accelerated entropy calculation.

    Args:
        probs: Probability distribution (should sum to 1 along axis)
        axis: Axis along which to compute entropy
        base: "e" for natural log, "2" for log2

    Returns:
        Entropy of ``probs`` reduced along ``axis``.

    Raises:
        ValueError: If ``base`` is not one of {"e", "2"}. Previously any
            unrecognized value silently fell back to log2, which could
            produce wrong results for a typo like "E" or "10".
    """
    if base == "e":
        log_fn = mx.log
    elif base == "2":
        log_fn = mx.log2
    else:
        raise ValueError(
            f"Unsupported base for mlx_entropy: {base!r}. Expected 'e' or '2'."
        )
    # Small epsilon guards against log(0) on zero-probability entries.
    return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)


def mlx_cosine_similarity_matrix(vectors: mx.array) -> mx.array:
    """GPU-accelerated pairwise cosine similarity.

    Args:
        vectors: Shape (n, d) where n is number of vectors, d is dimension

    Returns:
        Similarity matrix of shape (n, n)
    """
    # Scale each row to unit length; epsilon avoids division by zero
    # for all-zero rows.
    row_norms = mlx_norm(vectors, axis=-1, keepdims=True)
    unit_rows = vectors / (row_norms + 1e-8)
    # Dot products of unit vectors are exactly the cosine similarities.
    return mx.matmul(unit_rows, unit_rows.T)


def mlx_to_numpy(x: np.ndarray) -> np.ndarray:
    """Convert an MLX array to NumPy for visualization (Plotly, etc.).

    Forces evaluation of the lazily-computed MLX array first, so callers
    no longer need a separate ``mx.eval(...)`` before converting.

    Args:
        x: MLX array to convert.

    Returns:
        A host-side NumPy copy of ``x``.
    """
    # MLX computation is lazy; evaluate before materializing on the host.
    mx.eval(x)
    return np.array(x)


def numpy_to_mlx(x: np.ndarray) -> mx.array:
    """Convert a NumPy array to an MLX array for GPU computation.

    Args:
        x: Host-side NumPy array.

    Returns:
        An MLX array holding the same data.
    """
    gpu_array = mx.array(x)
    return gpu_array


# =============================================================================
# Performance Utilities - Caching and Background Processing
# =============================================================================
Expand Down Expand Up @@ -1864,23 +1916,23 @@ def _capture_logit_lens(self, layer_idx: int, hidden_state: mx.array):
positions_to_capture = [seq_len - 1]

for pos in positions_to_capture:
# Extract single position and immediately convert to numpy to free MLX memory
logits_pos = layer_logits[0, pos, :]
mx.eval(logits_pos)
# Extract single position
logits_pos = layer_logits[0, pos, :].astype(mx.float32)

# Compute probabilities on GPU before converting to NumPy
probs_mx = mlx_softmax(logits_pos)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider extracting shared MLX→NumPy conversion and GPU-handling into small helpers so the probing and plotting code stays simpler and more focused on NumPy-level logic.

You can reduce the added complexity by centralizing the MLX→NumPy conversion and separating GPU concerns from the probing logic.

1. Factor out a safe MLX→NumPy helper

The try/except + tolist() pattern is duplicated and mixes evaluation with conversion:

probs_mx = mlx_softmax(logits_pos)
mx.eval(logits_pos, probs_mx)
try:
    logits_np = np.array(logits_pos)
    probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
    logits_np = np.array(logits_pos.tolist(), dtype=np.float32)
    probs = np.array(probs_mx.tolist(), dtype=np.float32)

Introduce a small helper that you can reuse everywhere:

def safe_mlx_to_numpy(x: mx.array, dtype=np.float32) -> np.ndarray:
    mx.eval(x)
    try:
        return np.array(x, dtype=dtype)
    except (RuntimeError, TypeError, ValueError):
        return np.array(x.tolist(), dtype=dtype)

Then simplify the call sites:

logits_pos = layer_logits[0, pos, :].astype(mx.float32)
probs_mx = mlx_softmax(logits_pos)

logits_np = safe_mlx_to_numpy(logits_pos)
probs = safe_mlx_to_numpy(probs_mx)

del logits_pos, probs_mx

And similarly in _capture_logit_lens:

logits_slice = logits[0, -1, :].astype(mx.float32)

if self.config.capture_token_probs:
    probs_mx = mlx_softmax(logits_slice)
    self.results.token_probs = safe_mlx_to_numpy(probs_mx)

self.results.logits = safe_mlx_to_numpy(logits_slice)

And in the full-sequence block:

pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)

pos_logits = safe_mlx_to_numpy(pos_logits_mx)
pos_probs = safe_mlx_to_numpy(pos_probs_mx)

This keeps behavior identical but removes repeated try/except and evaluation boilerplate.

2. Separate “what” from “how” in the probing logic

Right now _capture_logit_lens interleaves:

  • computing softmax on GPU,
  • evaluating MLX tensors,
  • converting to NumPy,
  • assigning to self.results.

You can isolate the GPU handling into a tiny utility and keep probing logic focused on results:

def compute_logits_and_probs(
    logits: mx.array, capture_probs: bool
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    logits_slice = logits[0, -1, :].astype(mx.float32)
    if capture_probs:
        probs_mx = mlx_softmax(logits_slice)
        logits_np = safe_mlx_to_numpy(logits_slice)
        probs_np = safe_mlx_to_numpy(probs_mx)
        return logits_np, probs_np
    else:
        logits_np = safe_mlx_to_numpy(logits_slice)
        return logits_np, None

Usage in _capture_logit_lens:

self.results.logits, self.results.token_probs = compute_logits_and_probs(
    logits, self.config.capture_token_probs
)

The rest of the method can then work purely with NumPy and business logic.

3. Simplify plot_layer_similarity GPU path

You already have mlx_to_numpy and don’t need the extra staging via NumPy stacking:

vectors_mx = mx.stack(
    [mx.array(results.layer_outputs[idx].flatten()) for idx in layer_indices]
)
similarity_mx = mlx_cosine_similarity_matrix(vectors_mx)
similarity = mlx_to_numpy(similarity_mx)

This removes the explicit np.zeros initialization, the manual vectors list plus np.stack, and the separate mx.eval(similarity_mx) call (since mlx_to_numpy can handle evaluation if you fold that into it in the future).

mx.eval(logits_pos, probs_mx) # Evaluate both together

# Convert to numpy immediately
# Convert to numpy
try:
logits_np = np.array(logits_pos.astype(mx.float32))
logits_np = np.array(logits_pos)
probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
logits_np = np.array(logits_pos.tolist(), dtype=np.float32)
probs = np.array(probs_mx.tolist(), dtype=np.float32)

# Free the MLX tensor
del logits_pos

# Compute probabilities
max_logit = logits_np.max()
exp_logits = np.exp(logits_np - max_logit)
probs = exp_logits / exp_logits.sum()
# Free the MLX tensors
del logits_pos, probs_mx

# Get top-5 predictions
top_k = 5
Expand Down Expand Up @@ -2325,18 +2377,23 @@ def compute_norm(arr: np.ndarray) -> float:

mx.eval(logits)
logits_slice = logits[0, -1, :].astype(mx.float32)
mx.eval(logits_slice)

# Compute probabilities on GPU before converting to NumPy
if self.config.capture_token_probs:
probs_mx = mlx_softmax(logits_slice)
mx.eval(logits_slice, probs_mx) # Evaluate both together
try:
self.results.token_probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
self.results.token_probs = np.array(probs_mx.tolist(), dtype=np.float32)
else:
mx.eval(logits_slice)

try:
self.results.logits = np.array(logits_slice)
except (RuntimeError, TypeError, ValueError):
self.results.logits = np.array(logits_slice.tolist(), dtype=np.float32)

# Compute probabilities
if self.config.capture_token_probs:
max_logit = self.results.logits.max()
exp_logits = np.exp(self.results.logits - max_logit)
self.results.token_probs = exp_logits / exp_logits.sum()

# Get top-k tokens with text
top_k = 20
top_indices = np.argsort(self.results.logits)[-top_k:][::-1]
Expand Down Expand Up @@ -2372,11 +2429,12 @@ def compute_norm(arr: np.ndarray) -> float:
for pos in range(min(seq_len, self.config.max_sequence_positions, 512)):
if pos not in self.results.logit_lens_by_position:
self.results.logit_lens_by_position[pos] = {}
# For non-last positions, compute from full logits
pos_logits = np.array(logits[0, pos, :].astype(mx.float32))
max_logit = pos_logits.max()
exp_logits = np.exp(pos_logits - max_logit)
pos_probs = exp_logits / exp_logits.sum()
# Compute softmax on GPU before converting to NumPy
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
mx.eval(pos_logits_mx, pos_probs_mx)
pos_logits = np.array(pos_logits_mx)
pos_probs = np.array(pos_probs_mx)
pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
pos_preds = []
Comment on lines +2433 to 2439
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Explicitly free per-position MLX tensors inside the loop to reduce peak GPU memory usage.

Within this per-position loop, pos_logits_mx and pos_probs_mx are reallocated each iteration and only freed when the loop exits. For long sequences and many layers, this can spike peak GPU memory. Similar to the earlier del logits_pos, probs_mx, consider adding del pos_logits_mx, pos_probs_mx at the end of the loop body to keep GPU memory usage lower during logit lens capture.

Suggested change
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
mx.eval(pos_logits_mx, pos_probs_mx)
pos_logits = np.array(pos_logits_mx)
pos_probs = np.array(pos_probs_mx)
pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
pos_preds = []
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
mx.eval(pos_logits_mx, pos_probs_mx)
pos_logits = np.array(pos_logits_mx)
pos_probs = np.array(pos_probs_mx)
# Free per-position MLX tensors to reduce peak GPU memory usage
del pos_logits_mx, pos_probs_mx
pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
pos_preds = []

for idx in pos_top_indices:
Expand Down Expand Up @@ -4568,24 +4626,24 @@ def plot_token_probabilities(results: ProbeResults) -> go.Figure:


def plot_layer_similarity(results: ProbeResults) -> go.Figure:
"""Plot cosine similarity between layers."""
"""Plot cosine similarity between layers using GPU-accelerated matrix multiplication."""
layer_indices = sorted(results.layer_outputs.keys())
n = len(layer_indices)

if n < 2:
return go.Figure()

similarity = np.zeros((n, n))

# Collect vectors and convert to MLX for GPU computation
vectors = []
for idx in layer_indices:
vec = results.layer_outputs[idx].flatten()
vec = vec / (np.linalg.norm(vec) + 1e-8)
vectors.append(vec)

for i in range(n):
for j in range(n):
similarity[i, j] = np.dot(vectors[i], vectors[j])
# Stack into matrix and compute similarity on GPU (eliminates O(n²) loop)
vectors_mx = mx.array(np.stack(vectors))
similarity_mx = mlx_cosine_similarity_matrix(vectors_mx)
mx.eval(similarity_mx) # Force evaluation before converting to NumPy
similarity = mlx_to_numpy(similarity_mx)

labels = [f"L{idx}" for idx in layer_indices]

Expand Down