-
Notifications
You must be signed in to change notification settings - Fork 0
Experimental: Replace NumPy with MLX for GPU acceleration #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # MLXLMProbe - Project Context | ||
|
|
||
| ## Mechanistic Interpretability Resources | ||
|
|
||
| - **MI Glossary**: https://www.neelnanda.io/mechanistic-interpretability/glossary | ||
| - **TransformerLens**: https://github.com/TransformerLensOrg/TransformerLens | ||
|
|
||
| ## Experimental: NumPy → MLX GPU Acceleration (2026-01-28) | ||
|
|
||
| **Branch:** `experimental/numpy-to-mlx-gpu` | ||
| **Commit:** `d40b2a9` | ||
| **Goal:** Replace heavy NumPy computations with MLX for Apple Silicon GPU acceleration (also compatible with mlx-cuda on Linux) | ||
|
|
||
| ### MLX Utility Functions Added | ||
|
|
||
| | Function | Purpose | | ||
| |----------|---------| | ||
| | `mlx_softmax()` | GPU-accelerated softmax | | ||
| | `mlx_norm()` | GPU-accelerated L2 norm | | ||
| | `mlx_entropy()` | GPU-accelerated entropy calculation | | ||
| | `mlx_cosine_similarity_matrix()` | Batch matrix multiply for pairwise similarity | | ||
| | `mlx_to_numpy()` / `numpy_to_mlx()` | Conversion helpers | | ||
|
|
||
| ### High-Impact Operations Replaced | ||
|
|
||
| | Location | Operation | Change | | ||
| |----------|-----------|--------| | ||
| | Logit Lens (~line 1920) | Softmax | Compute on GPU before NumPy conversion | | ||
| | Token Probs (~line 2380) | Softmax | Compute on GPU before NumPy conversion | | ||
| | Full Sequence Logit Lens (~line 2430) | Softmax | Compute on GPU before NumPy conversion | | ||
| | Layer Similarity (~line 4630) | O(n²) cosine similarity | Single `mx.matmul()` call | | ||
|
|
||
| ### Not Changed (data already NumPy) | ||
|
|
||
| - Entropy calculations in interpretation functions | ||
| - Norm calculations in visualization functions | ||
| - Statistical operations for plotting (Plotly requires NumPy) | ||
|
|
||
| ### Assessment | ||
|
|
||
| - **Probability of success:** 85% | ||
| - **Biggest win:** Layer similarity matrix - eliminates O(n²) nested loop | ||
| - **Key principle:** Minimize MLX↔NumPy transfers by computing on GPU before conversion |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -79,6 +79,58 @@ | |||||||||||||||||||||||||||||||||
| __version__ = "0.1.5" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||
| # MLX GPU-Accelerated Utilities | ||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||
| # These functions leverage MLX for GPU acceleration on Apple Silicon (and mlx-cuda on Linux). | ||||||||||||||||||||||||||||||||||
| # They avoid unnecessary MLX↔NumPy transfers by keeping computation on GPU. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def mlx_softmax(x: mx.array, axis: int = -1) -> mx.array: | ||||||||||||||||||||||||||||||||||
| """GPU-accelerated softmax. Keeps data on GPU.""" | ||||||||||||||||||||||||||||||||||
| return mx.softmax(x, axis=axis) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array: | ||||||||||||||||||||||||||||||||||
| """GPU-accelerated L2 norm.""" | ||||||||||||||||||||||||||||||||||
| return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array: | ||||||||||||||||||||||||||||||||||
| """GPU-accelerated entropy calculation. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||
| probs: Probability distribution (should sum to 1 along axis) | ||||||||||||||||||||||||||||||||||
| axis: Axis along which to compute entropy | ||||||||||||||||||||||||||||||||||
| base: "e" for natural log, "2" for log2 | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| log_fn = mx.log if base == "e" else mx.log2 | ||||||||||||||||||||||||||||||||||
| return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def mlx_cosine_similarity_matrix(vectors: mx.array) -> mx.array: | ||||||||||||||||||||||||||||||||||
| """GPU-accelerated pairwise cosine similarity. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||
| vectors: Shape (n, d) where n is number of vectors, d is dimension | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||
| Similarity matrix of shape (n, n) | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| norms = mlx_norm(vectors, axis=-1, keepdims=True) | ||||||||||||||||||||||||||||||||||
| normalized = vectors / (norms + 1e-8) | ||||||||||||||||||||||||||||||||||
| return mx.matmul(normalized, normalized.T) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def mlx_to_numpy(x: mx.array) -> np.ndarray: | ||||||||||||||||||||||||||||||||||
| """Convert MLX array to NumPy for visualization (Plotly, etc.).""" | ||||||||||||||||||||||||||||||||||
| return np.array(x) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def numpy_to_mlx(x: np.ndarray) -> mx.array: | ||||||||||||||||||||||||||||||||||
| """Convert NumPy array to MLX for GPU computation.""" | ||||||||||||||||||||||||||||||||||
| return mx.array(x) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||
| # Performance Utilities - Caching and Background Processing | ||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||
|
|
@@ -1864,23 +1916,23 @@ def _capture_logit_lens(self, layer_idx: int, hidden_state: mx.array): | |||||||||||||||||||||||||||||||||
| positions_to_capture = [seq_len - 1] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| for pos in positions_to_capture: | ||||||||||||||||||||||||||||||||||
| # Extract single position and immediately convert to numpy to free MLX memory | ||||||||||||||||||||||||||||||||||
| logits_pos = layer_logits[0, pos, :] | ||||||||||||||||||||||||||||||||||
| mx.eval(logits_pos) | ||||||||||||||||||||||||||||||||||
| # Extract single position | ||||||||||||||||||||||||||||||||||
| logits_pos = layer_logits[0, pos, :].astype(mx.float32) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Compute probabilities on GPU before converting to NumPy | ||||||||||||||||||||||||||||||||||
| probs_mx = mlx_softmax(logits_pos) | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (complexity): Consider extracting shared MLX→NumPy conversion and GPU-handling into small helpers so the probing and plotting code stays simpler and more focused on NumPy-level logic. You can reduce the added complexity by centralizing the MLX→NumPy conversion and separating GPU concerns from the probing logic. 1. Factor out a safe MLX→NumPy helperThe probs_mx = mlx_softmax(logits_pos)
mx.eval(logits_pos, probs_mx)
try:
logits_np = np.array(logits_pos)
probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
logits_np = np.array(logits_pos.tolist(), dtype=np.float32)
probs = np.array(probs_mx.tolist(), dtype=np.float32)Introduce a small helper that you can reuse everywhere: def safe_mlx_to_numpy(x: mx.array, dtype=np.float32) -> np.ndarray:
mx.eval(x)
try:
return np.array(x, dtype=dtype)
except (RuntimeError, TypeError, ValueError):
return np.array(x.tolist(), dtype=dtype)Then simplify the call sites: logits_pos = layer_logits[0, pos, :].astype(mx.float32)
probs_mx = mlx_softmax(logits_pos)
logits_np = safe_mlx_to_numpy(logits_pos)
probs = safe_mlx_to_numpy(probs_mx)
del logits_pos, probs_mxAnd similarly in logits_slice = logits[0, -1, :].astype(mx.float32)
if self.config.capture_token_probs:
probs_mx = mlx_softmax(logits_slice)
self.results.token_probs = safe_mlx_to_numpy(probs_mx)
self.results.logits = safe_mlx_to_numpy(logits_slice)And in the full-sequence block: pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
pos_logits = safe_mlx_to_numpy(pos_logits_mx)
pos_probs = safe_mlx_to_numpy(pos_probs_mx)This keeps behavior identical but removes repeated 2. Separate “what” from “how” in the probing logicRight now
You can isolate the GPU handling into a tiny utility and keep probing logic focused on results: def compute_logits_and_probs(
logits: mx.array, capture_probs: bool
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
logits_slice = logits[0, -1, :].astype(mx.float32)
if capture_probs:
probs_mx = mlx_softmax(logits_slice)
logits_np = safe_mlx_to_numpy(logits_slice)
probs_np = safe_mlx_to_numpy(probs_mx)
return logits_np, probs_np
else:
logits_np = safe_mlx_to_numpy(logits_slice)
return logits_np, NoneUsage in self.results.logits, self.results.token_probs = compute_logits_and_probs(
logits, self.config.capture_token_probs
)The rest of the method can then work purely with NumPy and business logic. 3. Simplify
|
||||||||||||||||||||||||||||||||||
| mx.eval(logits_pos, probs_mx) # Evaluate both together | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Convert to numpy immediately | ||||||||||||||||||||||||||||||||||
| # Convert to numpy | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| logits_np = np.array(logits_pos.astype(mx.float32)) | ||||||||||||||||||||||||||||||||||
| logits_np = np.array(logits_pos) | ||||||||||||||||||||||||||||||||||
| probs = np.array(probs_mx) | ||||||||||||||||||||||||||||||||||
| except (RuntimeError, TypeError, ValueError): | ||||||||||||||||||||||||||||||||||
| logits_np = np.array(logits_pos.tolist(), dtype=np.float32) | ||||||||||||||||||||||||||||||||||
| probs = np.array(probs_mx.tolist(), dtype=np.float32) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Free the MLX tensor | ||||||||||||||||||||||||||||||||||
| del logits_pos | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Compute probabilities | ||||||||||||||||||||||||||||||||||
| max_logit = logits_np.max() | ||||||||||||||||||||||||||||||||||
| exp_logits = np.exp(logits_np - max_logit) | ||||||||||||||||||||||||||||||||||
| probs = exp_logits / exp_logits.sum() | ||||||||||||||||||||||||||||||||||
| # Free the MLX tensors | ||||||||||||||||||||||||||||||||||
| del logits_pos, probs_mx | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Get top-5 predictions | ||||||||||||||||||||||||||||||||||
| top_k = 5 | ||||||||||||||||||||||||||||||||||
|
|
@@ -2325,18 +2377,23 @@ def compute_norm(arr: np.ndarray) -> float: | |||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| mx.eval(logits) | ||||||||||||||||||||||||||||||||||
| logits_slice = logits[0, -1, :].astype(mx.float32) | ||||||||||||||||||||||||||||||||||
| mx.eval(logits_slice) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Compute probabilities on GPU before converting to NumPy | ||||||||||||||||||||||||||||||||||
| if self.config.capture_token_probs: | ||||||||||||||||||||||||||||||||||
| probs_mx = mlx_softmax(logits_slice) | ||||||||||||||||||||||||||||||||||
| mx.eval(logits_slice, probs_mx) # Evaluate both together | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| self.results.token_probs = np.array(probs_mx) | ||||||||||||||||||||||||||||||||||
| except (RuntimeError, TypeError, ValueError): | ||||||||||||||||||||||||||||||||||
| self.results.token_probs = np.array(probs_mx.tolist(), dtype=np.float32) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| mx.eval(logits_slice) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| self.results.logits = np.array(logits_slice) | ||||||||||||||||||||||||||||||||||
| except (RuntimeError, TypeError, ValueError): | ||||||||||||||||||||||||||||||||||
| self.results.logits = np.array(logits_slice.tolist(), dtype=np.float32) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Compute probabilities | ||||||||||||||||||||||||||||||||||
| if self.config.capture_token_probs: | ||||||||||||||||||||||||||||||||||
| max_logit = self.results.logits.max() | ||||||||||||||||||||||||||||||||||
| exp_logits = np.exp(self.results.logits - max_logit) | ||||||||||||||||||||||||||||||||||
| self.results.token_probs = exp_logits / exp_logits.sum() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Get top-k tokens with text | ||||||||||||||||||||||||||||||||||
| top_k = 20 | ||||||||||||||||||||||||||||||||||
| top_indices = np.argsort(self.results.logits)[-top_k:][::-1] | ||||||||||||||||||||||||||||||||||
|
|
@@ -2372,11 +2429,12 @@ def compute_norm(arr: np.ndarray) -> float: | |||||||||||||||||||||||||||||||||
| for pos in range(min(seq_len, self.config.max_sequence_positions, 512)): | ||||||||||||||||||||||||||||||||||
| if pos not in self.results.logit_lens_by_position: | ||||||||||||||||||||||||||||||||||
| self.results.logit_lens_by_position[pos] = {} | ||||||||||||||||||||||||||||||||||
| # For non-last positions, compute from full logits | ||||||||||||||||||||||||||||||||||
| pos_logits = np.array(logits[0, pos, :].astype(mx.float32)) | ||||||||||||||||||||||||||||||||||
| max_logit = pos_logits.max() | ||||||||||||||||||||||||||||||||||
| exp_logits = np.exp(pos_logits - max_logit) | ||||||||||||||||||||||||||||||||||
| pos_probs = exp_logits / exp_logits.sum() | ||||||||||||||||||||||||||||||||||
| # Compute softmax on GPU before converting to NumPy | ||||||||||||||||||||||||||||||||||
| pos_logits_mx = logits[0, pos, :].astype(mx.float32) | ||||||||||||||||||||||||||||||||||
| pos_probs_mx = mlx_softmax(pos_logits_mx) | ||||||||||||||||||||||||||||||||||
| mx.eval(pos_logits_mx, pos_probs_mx) | ||||||||||||||||||||||||||||||||||
| pos_logits = np.array(pos_logits_mx) | ||||||||||||||||||||||||||||||||||
| pos_probs = np.array(pos_probs_mx) | ||||||||||||||||||||||||||||||||||
| pos_top_indices = np.argsort(pos_logits)[-5:][::-1] | ||||||||||||||||||||||||||||||||||
| pos_preds = [] | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+2433
to
2439
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (performance): Explicitly free per-position MLX tensors inside the loop to reduce peak GPU memory usage. Within this per-position loop,
Suggested change
|
||||||||||||||||||||||||||||||||||
| for idx in pos_top_indices: | ||||||||||||||||||||||||||||||||||
|
|
@@ -4568,24 +4626,24 @@ def plot_token_probabilities(results: ProbeResults) -> go.Figure: | |||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def plot_layer_similarity(results: ProbeResults) -> go.Figure: | ||||||||||||||||||||||||||||||||||
| """Plot cosine similarity between layers.""" | ||||||||||||||||||||||||||||||||||
| """Plot cosine similarity between layers using GPU-accelerated matrix multiplication.""" | ||||||||||||||||||||||||||||||||||
| layer_indices = sorted(results.layer_outputs.keys()) | ||||||||||||||||||||||||||||||||||
| n = len(layer_indices) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if n < 2: | ||||||||||||||||||||||||||||||||||
| return go.Figure() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| similarity = np.zeros((n, n)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Collect vectors and convert to MLX for GPU computation | ||||||||||||||||||||||||||||||||||
| vectors = [] | ||||||||||||||||||||||||||||||||||
| for idx in layer_indices: | ||||||||||||||||||||||||||||||||||
| vec = results.layer_outputs[idx].flatten() | ||||||||||||||||||||||||||||||||||
| vec = vec / (np.linalg.norm(vec) + 1e-8) | ||||||||||||||||||||||||||||||||||
| vectors.append(vec) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| for i in range(n): | ||||||||||||||||||||||||||||||||||
| for j in range(n): | ||||||||||||||||||||||||||||||||||
| similarity[i, j] = np.dot(vectors[i], vectors[j]) | ||||||||||||||||||||||||||||||||||
| # Stack into matrix and compute similarity on GPU (eliminates O(n²) loop) | ||||||||||||||||||||||||||||||||||
| vectors_mx = mx.array(np.stack(vectors)) | ||||||||||||||||||||||||||||||||||
| similarity_mx = mlx_cosine_similarity_matrix(vectors_mx) | ||||||||||||||||||||||||||||||||||
| mx.eval(similarity_mx) # Force evaluation before converting to NumPy | ||||||||||||||||||||||||||||||||||
| similarity = mlx_to_numpy(similarity_mx) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| labels = [f"L{idx}" for idx in layer_indices] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (bug_risk): Handle unexpected
basevalues explicitly inmlx_entropy.Any
basevalue other than "e" currently falls back tomx.log2, which can silently produce wrong results if a caller passes an unexpected value (e.g. "E", "10"). Consider validating/normalizing the argument (e.g. only allow {"e", "2"} and raiseValueErrorotherwise) so misconfiguration is caught instead of defaulting to log2.