diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..390c65e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,43 @@ +# MLXLMProbe - Project Context + +## Mechanistic Interpretability Resources + +- **MI Glossary**: https://www.neelnanda.io/mechanistic-interpretability/glossary +- **TransformerLens**: https://github.com/TransformerLensOrg/TransformerLens + +## Experimental: NumPy → MLX GPU Acceleration (2026-01-28) + +**Branch:** `experimental/numpy-to-mlx-gpu` +**Commit:** `d40b2a9` +**Goal:** Replace heavy NumPy computations with MLX for Apple Silicon GPU acceleration (also compatible with mlx-cuda on Linux) + +### MLX Utility Functions Added + +| Function | Purpose | +|----------|---------| +| `mlx_softmax()` | GPU-accelerated softmax | +| `mlx_norm()` | GPU-accelerated L2 norm | +| `mlx_entropy()` | GPU-accelerated entropy calculation | +| `mlx_cosine_similarity_matrix()` | Batch matrix multiply for pairwise similarity | +| `mlx_to_numpy()` / `numpy_to_mlx()` | Conversion helpers | + +### High-Impact Operations Replaced + +| Location | Operation | Change | +|----------|-----------|--------| +| Logit Lens (~line 1920) | Softmax | Compute on GPU before NumPy conversion | +| Token Probs (~line 2380) | Softmax | Compute on GPU before NumPy conversion | +| Full Sequence Logit Lens (~line 2430) | Softmax | Compute on GPU before NumPy conversion | +| Layer Similarity (~line 4630) | O(n²) cosine similarity | Single `mx.matmul()` call | + +### Not Changed (data already NumPy) + +- Entropy calculations in interpretation functions +- Norm calculations in visualization functions +- Statistical operations for plotting (Plotly requires NumPy) + +### Assessment + +- **Probability of success:** 85% +- **Biggest win:** Layer similarity matrix - replaces the O(n²) nested Python loop with a single `mx.matmul()` call +- **Key principle:** Minimize MLX↔NumPy transfers by computing on GPU before conversion diff --git a/mlxlmprobe.py b/mlxlmprobe.py index fd4bc56..8108638 100644 --- a/mlxlmprobe.py +++ 
b/mlxlmprobe.py @@ -79,6 +79,58 @@ __version__ = "0.1.5" +# ============================================================================= +# MLX GPU-Accelerated Utilities +# ============================================================================= +# These functions leverage MLX for GPU acceleration on Apple Silicon (and mlx-cuda on Linux). +# They avoid unnecessary MLX↔NumPy transfers by keeping computation on GPU. + +def mlx_softmax(x: mx.array, axis: int = -1) -> mx.array: + """Softmax of x along axis via mx.softmax; the result stays an MLX array.""" + return mx.softmax(x, axis=axis) + + +def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array: + """L2 (Euclidean) norm of x along axis, computed as sqrt(sum(x * x)).""" + return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims)) + + +def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array: + """Entropy -sum(p * log(p + 1e-10)) of probs along axis, computed with MLX ops. + + Args: + probs: Probability distribution (should sum to 1 along axis) + axis: Axis along which to compute entropy + base: "e" for natural log; any other value (e.g. "2") selects log2 + """ + log_fn = mx.log if base == "e" else mx.log2 + return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis) + + +def mlx_cosine_similarity_matrix(vectors: mx.array) -> mx.array: + """GPU-accelerated pairwise cosine similarity. 
+ + Args: + vectors: Shape (n, d) where n is number of vectors, d is dimension + + Returns: + Similarity matrix of shape (n, n); each row is divided by (its L2 norm + 1e-8), so an all-zero row yields 0 rather than NaN + """ + norms = mlx_norm(vectors, axis=-1, keepdims=True) + normalized = vectors / (norms + 1e-8) + return mx.matmul(normalized, normalized.T) + + +def mlx_to_numpy(x: mx.array) -> np.ndarray: + """Convert MLX array to NumPy for visualization (Plotly, etc.).""" + return np.array(x) + + +def numpy_to_mlx(x: np.ndarray) -> mx.array: + """Convert NumPy array to MLX for GPU computation.""" + return mx.array(x) + + # ============================================================================= # Performance Utilities - Caching and Background Processing # ============================================================================= @@ -1864,23 +1916,23 @@ def _capture_logit_lens(self, layer_idx: int, hidden_state: mx.array): positions_to_capture = [seq_len - 1] for pos in positions_to_capture: - # Extract single position and immediately convert to numpy to free MLX memory - logits_pos = layer_logits[0, pos, :] - mx.eval(logits_pos) + # Extract single position and cast to float32 before softmax and NumPy conversion + logits_pos = layer_logits[0, pos, :].astype(mx.float32) + + # Compute probabilities on GPU before converting to NumPy + probs_mx = mlx_softmax(logits_pos) + mx.eval(logits_pos, probs_mx) # Evaluate both together - # Convert to numpy immediately + # Convert to numpy (falls back to .tolist() if direct conversion raises) try: - logits_np = np.array(logits_pos.astype(mx.float32)) + logits_np = np.array(logits_pos) + probs = np.array(probs_mx) except (RuntimeError, TypeError, ValueError): logits_np = np.array(logits_pos.tolist(), dtype=np.float32) + probs = np.array(probs_mx.tolist(), dtype=np.float32) - # Free the MLX tensor - del logits_pos - - # Compute probabilities - max_logit = logits_np.max() - exp_logits = np.exp(logits_np - max_logit) - probs = exp_logits / exp_logits.sum() + # Free the MLX tensors + del logits_pos, probs_mx # Get top-5 predictions top_k = 5 @@ -2325,18 +2377,23 @@ def compute_norm(arr: np.ndarray) -> float: 
mx.eval(logits) logits_slice = logits[0, -1, :].astype(mx.float32) - mx.eval(logits_slice) + + # When token probs are requested, run softmax in MLX and evaluate logits and probs in one mx.eval + if self.config.capture_token_probs: + probs_mx = mlx_softmax(logits_slice) + mx.eval(logits_slice, probs_mx) # Evaluate both together + try: + self.results.token_probs = np.array(probs_mx) + except (RuntimeError, TypeError, ValueError): + self.results.token_probs = np.array(probs_mx.tolist(), dtype=np.float32) + else: + mx.eval(logits_slice) + try: self.results.logits = np.array(logits_slice) except (RuntimeError, TypeError, ValueError): self.results.logits = np.array(logits_slice.tolist(), dtype=np.float32) - # Compute probabilities - if self.config.capture_token_probs: - max_logit = self.results.logits.max() - exp_logits = np.exp(self.results.logits - max_logit) - self.results.token_probs = exp_logits / exp_logits.sum() - # Get top-k tokens with text top_k = 20 top_indices = np.argsort(self.results.logits)[-top_k:][::-1] @@ -2372,11 +2429,12 @@ def compute_norm(arr: np.ndarray) -> float: for pos in range(min(seq_len, self.config.max_sequence_positions, 512)): if pos not in self.results.logit_lens_by_position: self.results.logit_lens_by_position[pos] = {} - # For non-last positions, compute from full logits - pos_logits = np.array(logits[0, pos, :].astype(mx.float32)) - max_logit = pos_logits.max() - exp_logits = np.exp(pos_logits - max_logit) - pos_probs = exp_logits / exp_logits.sum() + # Per-position softmax in MLX on float32 logits; both arrays are evaluated before NumPy conversion + pos_logits_mx = logits[0, pos, :].astype(mx.float32) + pos_probs_mx = mlx_softmax(pos_logits_mx) + mx.eval(pos_logits_mx, pos_probs_mx) + pos_logits = np.array(pos_logits_mx) + pos_probs = np.array(pos_probs_mx) pos_top_indices = np.argsort(pos_logits)[-5:][::-1] pos_preds = [] for idx in pos_top_indices: @@ -4568,24 +4626,24 @@ def plot_token_probabilities(results: ProbeResults) -> go.Figure: def plot_layer_similarity(results: ProbeResults) -> go.Figure: - """Plot cosine 
similarity between layers.""" + """Plot cosine similarity between layers using GPU-accelerated matrix multiplication.""" layer_indices = sorted(results.layer_outputs.keys()) n = len(layer_indices) if n < 2: return go.Figure() - similarity = np.zeros((n, n)) - + # Collect flattened per-layer outputs; normalization now happens inside mlx_cosine_similarity_matrix vectors = [] for idx in layer_indices: vec = results.layer_outputs[idx].flatten() - vec = vec / (np.linalg.norm(vec) + 1e-8) vectors.append(vec) - for i in range(n): - for j in range(n): - similarity[i, j] = np.dot(vectors[i], vectors[j]) + # Stack into an (n, d) matrix (np.stack needs equal flattened sizes); one matmul replaces the O(n²) Python loop + vectors_mx = mx.array(np.stack(vectors)) + similarity_mx = mlx_cosine_similarity_matrix(vectors_mx) + mx.eval(similarity_mx) # Force evaluation before converting to NumPy + similarity = mlx_to_numpy(similarity_mx) labels = [f"L{idx}" for idx in layer_indices]