Experimental: Replace NumPy with MLX for GPU acceleration #2

scouzi1966 wants to merge 2 commits into `main` from `experimental/numpy-to-mlx-gpu`
Conversation
Add MLX GPU-accelerated utility functions:
- mlx_softmax() - GPU softmax
- mlx_norm() - GPU L2 norm
- mlx_entropy() - GPU entropy calculation
- mlx_cosine_similarity_matrix() - batch matrix multiply

Replace high-impact operations:
- Softmax in Logit Lens (compute on GPU before NumPy conversion)
- Softmax in token probability capture
- Layer similarity matrix: O(n²) loop → single mx.matmul()

These changes keep computation on GPU longer, reducing MLX↔NumPy transfers. Compatible with both Apple Silicon and mlx-cuda on Linux.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add summary of experimental/numpy-to-mlx-gpu branch to CLAUDE.md for historical reference. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
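The O(n²)-loop → single-matmul change described in the first commit can be illustrated with a NumPy-only reference of what `mlx_cosine_similarity_matrix` computes (a sketch, not the PR's code: the real helper runs on `mx.array` via `mx.sum`/`mx.sqrt`/`mx.matmul`, and the `1e-10` epsilon here is an assumption):

```python
import numpy as np

def cosine_similarity_matrix(vectors: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """All-pairs cosine similarity via one matrix multiply.

    Each row is normalized to unit L2 norm, then the matrix is
    multiplied by its transpose -- the same computation as the
    pairwise O(n^2) Python loop, done in a single matmul.
    """
    norms = np.sqrt(np.sum(vectors * vectors, axis=-1, keepdims=True))
    normalized = vectors / (norms + eps)
    return normalized @ normalized.T

# Toy "layer output" vectors: rows 0 and 1 are parallel, row 2 is orthogonal.
layers = np.array([[1.0, 0.0], [2.0, 0.0], [0.0, 3.0]])
sim = cosine_similarity_matrix(layers)
```

The resulting matrix is symmetric with a unit diagonal, which is what the similarity heatmap plots.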
Reviewer's Guide

Adds MLX-based GPU-accelerated utility functions and rewires key softmax and similarity computations to run on the GPU before converting to NumPy, aiming to reduce CPU-bound work and unnecessary MLX↔NumPy transfers while preserving existing visualizations.

Sequence diagram for GPU-accelerated layer similarity computation

```mermaid
sequenceDiagram
    participant UI as Caller_plot_layer_similarity
    participant PLS as PlotLayerSimilarity
    participant NP as NumPyRuntime
    participant MXU as MLXUtilities
    participant MX as MLXRuntime
    UI->>PLS: plot_layer_similarity(results)
    PLS->>PLS: collect layer_outputs as vectors
    PLS->>NP: stack(vectors)
    NP-->>PLS: vectors_np
    PLS->>MXU: numpy_to_mlx(vectors_np)
    MXU->>MX: array(vectors_np)
    MX-->>MXU: vectors_mx
    MXU-->>PLS: vectors_mx
    PLS->>MXU: mlx_cosine_similarity_matrix(vectors_mx)
    MXU->>MX: mlx_norm(vectors_mx)
    MX-->>MXU: norms
    MXU->>MX: matmul(normalized, normalized_T)
    MX-->>MXU: similarity_mx
    MXU-->>PLS: similarity_mx
    PLS->>MX: eval(similarity_mx)
    MX-->>PLS: realized_similarity_mx
    PLS->>MXU: mlx_to_numpy(realized_similarity_mx)
    MXU->>NP: array(realized_similarity_mx)
    NP-->>MXU: similarity_np
    MXU-->>PLS: similarity_np
    PLS-->>UI: similarity matrix for plotting
```
Class diagram for new MLX GPU utility functions

```mermaid
classDiagram
    class MLXUtilities {
        +mlx_softmax(x, axis)
        +mlx_norm(x, axis, keepdims)
        +mlx_entropy(probs, axis, base)
        +mlx_cosine_similarity_matrix(vectors)
        +mlx_to_numpy(x)
        +numpy_to_mlx(x)
    }
    class MLXRuntime {
        +softmax(x, axis)
        +sqrt(x)
        +sum(x, axis, keepdims)
        +log(x)
        +log2(x)
        +matmul(a, b)
        +array(x)
        +eval(x)
    }
    class NumPyRuntime {
        +array(x)
        +stack(vectors)
        +argsort(x)
        +linalg_norm(x)
    }
    class PlotLayerSimilarity {
        +plot_layer_similarity(results)
    }
    class LogitLensCapture {
        +_capture_logit_lens(layer_idx, hidden_state)
    }
    class TokenProbsCapture {
        +capture_token_probabilities()
    }
    MLXUtilities --> MLXRuntime : uses
    MLXUtilities --> NumPyRuntime : converts
    LogitLensCapture --> MLXUtilities : uses mlx_softmax
    TokenProbsCapture --> MLXUtilities : uses mlx_softmax
    PlotLayerSimilarity --> MLXUtilities : uses mlx_cosine_similarity_matrix
    PlotLayerSimilarity --> NumPyRuntime : converts similarity to NumPy
```
File-Level Changes
Hey - I've found 3 issues, and left some high level feedback:
- The new `mlx_entropy` helper silently falls back to `mx.log2` for any `base` other than `'e'`, which could hide misconfigurations — consider validating `base` and raising for unsupported values.
- Several places now convert MLX tensors to NumPy with `np.array(...)` instead of the new `mlx_to_numpy` helper; using the helper consistently would centralize any future changes to conversion and dtypes.
- In the token probability and logit-lens paths, you now hold both logits and probabilities in MLX simultaneously before deletion; if GPU memory becomes tight, consider computing probabilities from a sliced or reduced view, or freeing logits before allocating `probs_mx`.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The new `mlx_entropy` helper silently falls back to `mx.log2` for any `base` other than `'e'`, which could hide misconfigurations — consider validating `base` and raising for unsupported values.
- Several places now convert MLX tensors to NumPy with `np.array(...)` instead of the new `mlx_to_numpy` helper; using the helper consistently would centralize any future changes to conversion and dtypes.
- In the token probability and logit-lens paths, you now hold both logits and probabilities in MLX simultaneously before deletion; if GPU memory becomes tight, consider computing probabilities from a sliced or reduced view, or freeing logits before allocating `probs_mx`.
## Individual Comments
### Comment 1
<location> `mlxlmprobe.py:93-97` </location>
<code_context>
+    return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))
+
+
+def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
+    """GPU-accelerated entropy calculation.
+
+    Args:
+        probs: Probability distribution (should sum to 1 along axis)
+        axis: Axis along which to compute entropy
+        base: "e" for natural log, "2" for log2
+    """
+    log_fn = mx.log if base == "e" else mx.log2
+    return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Handle unexpected `base` values explicitly in `mlx_entropy`.
Any `base` value other than "e" currently falls back to `mx.log2`, which can silently produce wrong results if a caller passes an unexpected value (e.g. "E", "10"). Consider validating/normalizing the argument (e.g. only allow {"e", "2"} and raise `ValueError` otherwise) so misconfiguration is caught instead of defaulting to log2.
```suggestion
def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array:
    """GPU-accelerated L2 norm."""
    return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))


def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
    """GPU-accelerated entropy calculation.

    Args:
        probs: Probability distribution (should sum to 1 along axis)
        axis: Axis along which to compute entropy
        base: "e" for natural log, "2" for log2

    Raises:
        ValueError: If `base` is not one of {"e", "2"}.
    """
    if base == "e":
        log_fn = mx.log
    elif base == "2":
        log_fn = mx.log2
    else:
        raise ValueError(f"Unsupported base for mlx_entropy: {base!r}. Expected 'e' or '2'.")
    return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)
```
</issue_to_address>
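To make the suggested validation concrete, a NumPy mirror of the validated helper behaves as follows (the `entropy_ref` name is hypothetical; the PR's `mlx_entropy` operates on `mx.array` with `mx.log`/`mx.log2`):

```python
import numpy as np

def entropy_ref(probs: np.ndarray, axis: int = -1, base: str = "e") -> np.ndarray:
    """NumPy stand-in for the validated mlx_entropy: only 'e' and '2' allowed."""
    if base == "e":
        log_fn = np.log
    elif base == "2":
        log_fn = np.log2
    else:
        raise ValueError(f"Unsupported base: {base!r}. Expected 'e' or '2'.")
    # Epsilon guards against log(0) for zero-probability entries.
    return -np.sum(probs * log_fn(probs + 1e-10), axis=axis)

# A uniform distribution over 8 outcomes has entropy log2(8) = 3 bits,
# or ln(8) nats; base "10" now raises instead of silently using log2.
uniform = np.full(8, 1.0 / 8)
```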
### Comment 2
<location> `mlxlmprobe.py:2433-2439` </location>
<code_context>
- max_logit = pos_logits.max()
- exp_logits = np.exp(pos_logits - max_logit)
- pos_probs = exp_logits / exp_logits.sum()
+ # Compute softmax on GPU before converting to NumPy
+ pos_logits_mx = logits[0, pos, :].astype(mx.float32)
+ pos_probs_mx = mlx_softmax(pos_logits_mx)
+ mx.eval(pos_logits_mx, pos_probs_mx)
+ pos_logits = np.array(pos_logits_mx)
+ pos_probs = np.array(pos_probs_mx)
pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
pos_preds = []
</code_context>
<issue_to_address>
**suggestion (performance):** Explicitly free per-position MLX tensors inside the loop to reduce peak GPU memory usage.
Within this per-position loop, `pos_logits_mx` and `pos_probs_mx` are reallocated each iteration and only freed when the loop exits. For long sequences and many layers, this can spike peak GPU memory. Similar to the earlier `del logits_pos, probs_mx`, consider adding `del pos_logits_mx, pos_probs_mx` at the end of the loop body to keep GPU memory usage lower during logit lens capture.
```suggestion
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
mx.eval(pos_logits_mx, pos_probs_mx)
pos_logits = np.array(pos_logits_mx)
pos_probs = np.array(pos_probs_mx)
# Free per-position MLX tensors to reduce peak GPU memory usage
del pos_logits_mx, pos_probs_mx
pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
pos_preds = []
```
</issue_to_address>
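For context, the NumPy path being replaced in this hunk is the standard max-subtracted softmax; a standalone sketch shows why the shift matters (a NumPy stand-in, not the PR's MLX code):

```python
import numpy as np

def softmax_ref(logits: np.ndarray) -> np.ndarray:
    """Max-subtracted softmax, as in the NumPy code this PR replaces.

    Subtracting the max before exp() keeps the intermediate values in
    [0, 1], avoiding overflow for large logit magnitudes.
    """
    shifted = np.exp(logits - logits.max())
    return shifted / shifted.sum()

# Logits this large would overflow a naive exp(); fine after the shift.
big = np.array([1000.0, 1001.0, 999.0])
probs = softmax_ref(big)
```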
### Comment 3
<location> `mlxlmprobe.py:1923` </location>
<code_context>
+ logits_pos = layer_logits[0, pos, :].astype(mx.float32)
+
+ # Compute probabilities on GPU before converting to NumPy
+ probs_mx = mlx_softmax(logits_pos)
+ mx.eval(logits_pos, probs_mx) # Evaluate both together
</code_context>
<issue_to_address>
**issue (complexity):** Consider extracting shared MLX→NumPy conversion and GPU-handling into small helpers so the probing and plotting code stays simpler and more focused on NumPy-level logic.
You can reduce the added complexity by centralizing the MLX→NumPy conversion and separating GPU concerns from the probing logic.
### 1. Factor out a safe MLX→NumPy helper
The `try/except` + `tolist()` pattern is duplicated and mixes evaluation with conversion:
```python
probs_mx = mlx_softmax(logits_pos)
mx.eval(logits_pos, probs_mx)
try:
    logits_np = np.array(logits_pos)
    probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
    logits_np = np.array(logits_pos.tolist(), dtype=np.float32)
    probs = np.array(probs_mx.tolist(), dtype=np.float32)
```
Introduce a small helper that you can reuse everywhere:
```python
def safe_mlx_to_numpy(x: mx.array, dtype=np.float32) -> np.ndarray:
    mx.eval(x)
    try:
        return np.array(x, dtype=dtype)
    except (RuntimeError, TypeError, ValueError):
        return np.array(x.tolist(), dtype=dtype)
```
Then simplify the call sites:
```python
logits_pos = layer_logits[0, pos, :].astype(mx.float32)
probs_mx = mlx_softmax(logits_pos)
logits_np = safe_mlx_to_numpy(logits_pos)
probs = safe_mlx_to_numpy(probs_mx)
del logits_pos, probs_mx
```
And similarly in `_capture_logit_lens`:
```python
logits_slice = logits[0, -1, :].astype(mx.float32)
if self.config.capture_token_probs:
    probs_mx = mlx_softmax(logits_slice)
    self.results.token_probs = safe_mlx_to_numpy(probs_mx)
self.results.logits = safe_mlx_to_numpy(logits_slice)
```
And in the full-sequence block:
```python
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)
pos_logits = safe_mlx_to_numpy(pos_logits_mx)
pos_probs = safe_mlx_to_numpy(pos_probs_mx)
```
This keeps behavior identical but removes repeated `try/except` and evaluation boilerplate.
### 2. Separate “what” from “how” in the probing logic
Right now `_capture_logit_lens` interleaves:
- computing softmax on GPU,
- evaluating MLX tensors,
- converting to NumPy,
- assigning to `self.results`.
You can isolate the GPU handling into a tiny utility and keep probing logic focused on results:
```python
def compute_logits_and_probs(
    logits: mx.array, capture_probs: bool
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    logits_slice = logits[0, -1, :].astype(mx.float32)
    if capture_probs:
        probs_mx = mlx_softmax(logits_slice)
        logits_np = safe_mlx_to_numpy(logits_slice)
        probs_np = safe_mlx_to_numpy(probs_mx)
        return logits_np, probs_np
    else:
        logits_np = safe_mlx_to_numpy(logits_slice)
        return logits_np, None
```
Usage in `_capture_logit_lens`:
```python
self.results.logits, self.results.token_probs = compute_logits_and_probs(
    logits, self.config.capture_token_probs
)
```
The rest of the method can then work purely with NumPy and business logic.
### 3. Simplify `plot_layer_similarity` GPU path
You already have `mlx_to_numpy` and don’t need the extra staging via NumPy stacking:
```python
vectors_mx = mx.stack(
    [mx.array(results.layer_outputs[idx].flatten()) for idx in layer_indices]
)
similarity_mx = mlx_cosine_similarity_matrix(vectors_mx)
similarity = mlx_to_numpy(similarity_mx)
```
This removes the explicit `np.zeros` initialization, the manual `vectors` list plus `np.stack`, and the separate `mx.eval(similarity_mx)` call (since `mlx_to_numpy` can handle evaluation if you fold that into it in the future).
</issue_to_address>
Summary

- Add GPU-accelerated utility functions (`mlx_softmax`, `mlx_norm`, `mlx_entropy`, `mlx_cosine_similarity_matrix`)
- Replace the O(n²) layer-similarity loop with a single `mx.matmul()` call

Impact

- Keeps computation on the GPU longer, reducing MLX↔NumPy transfers; layer similarity is now one `mx.matmul()` instead of a Python loop

Compatibility

Works on both Apple Silicon (MLX) and Linux (mlx-cuda).

Test plan
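One way to sanity-check the loop → matmul rewrite is a NumPy-only equivalence test (a sketch with hypothetical names; the actual code runs the matmul path in MLX, so results may additionally differ by float32 rounding):

```python
import numpy as np

def loop_similarity(vectors: np.ndarray) -> np.ndarray:
    """The replaced O(n^2) path: pairwise cosine similarity in Python loops."""
    n = len(vectors)
    sim = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            a, b = vectors[i], vectors[j]
            sim[i, j] = (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)
    return sim

def matmul_similarity(vectors: np.ndarray) -> np.ndarray:
    """The new path: normalize rows once, then a single matrix multiply."""
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    normalized = vectors / (norms + 1e-10)
    return normalized @ normalized.T

# Stand-in for flattened per-layer hidden states.
rng = np.random.default_rng(0)
fake_layer_outputs = rng.normal(size=(8, 16))
```

Both functions should agree to within floating-point tolerance on random inputs.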
🤖 Generated with Claude Code
Summary by Sourcery

Introduce MLX-based GPU-accelerated helpers and use them for key probability and similarity computations to reduce CPU-bound NumPy work and improve performance.

New Features:
- MLX GPU utility functions: `mlx_softmax`, `mlx_norm`, `mlx_entropy`, `mlx_cosine_similarity_matrix`

Enhancements:
- Compute softmax for the Logit Lens and token-probability capture on the GPU before converting to NumPy
- Replace the O(n²) layer-similarity loop with a single `mx.matmul()`

Documentation:
- Add a summary of the experimental/numpy-to-mlx-gpu branch to CLAUDE.md