
Experimental: Replace NumPy with MLX for GPU acceleration #2

Open
scouzi1966 wants to merge 2 commits into main from experimental/numpy-to-mlx-gpu

Conversation

@scouzi1966
Owner

@scouzi1966 scouzi1966 commented Jan 28, 2026

Summary

  • Add MLX GPU-accelerated utility functions (mlx_softmax, mlx_norm, mlx_entropy, mlx_cosine_similarity_matrix)
  • Replace softmax computations to run on GPU before NumPy conversion (3 locations)
  • Replace layer similarity O(n²) nested loop with single mx.matmul() call
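
The math these helpers implement can be sketched in plain NumPy (reference only — the branch uses the `mx` equivalents; the function names here are illustrative, not the actual helpers):

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # Numerically stable softmax: subtract the max before exponentiating
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

def cosine_similarity_matrix_ref(vectors):
    # Row-normalize, then one matmul yields all pairwise cosine similarities
    norms = np.sqrt(np.sum(vectors * vectors, axis=-1, keepdims=True))
    normalized = vectors / (norms + 1e-10)
    return normalized @ normalized.T
```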

Impact

| Location | Operation | Change |
| --- | --- | --- |
| Logit Lens | Softmax | GPU before NumPy conversion |
| Token Probs | Softmax | GPU before NumPy conversion |
| Full Sequence Logit Lens | Softmax | GPU before NumPy conversion |
| Layer Similarity | O(n²) cosine similarity | Single mx.matmul() |

Compatibility

Works on both Apple Silicon (MLX) and Linux (mlx-cuda).

Test plan

  • Run with GPT-OSS-20B and compare timing before/after
  • Verify numerical results are equivalent (within float tolerance)
  • Check memory usage doesn't increase
  • Ensure Plotly visualizations still work correctly
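
The equivalence check could look like the sketch below. It compares the pre-change NumPy softmax at float64 against a float32 run standing in for the GPU result (in a real run, the new probabilities come from `np.array(mlx_softmax(...))`; the vocab size is an assumption):

```python
import numpy as np

def stable_softmax(pos_logits):
    # The pre-change NumPy path, as it appeared before this branch
    max_logit = pos_logits.max()
    exp_logits = np.exp(pos_logits - max_logit)
    return exp_logits / exp_logits.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=50_000)  # vocab-sized logit vector (illustrative)
probs32 = stable_softmax(logits.astype(np.float32))  # stand-in for GPU result
probs64 = stable_softmax(logits.astype(np.float64))  # high-precision reference
assert np.allclose(probs32, probs64, rtol=1e-3, atol=1e-8)
```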

🤖 Generated with Claude Code

Summary by Sourcery

Introduce MLX-based GPU-accelerated helpers and use them for key probability and similarity computations to reduce CPU-bound NumPy work and improve performance.

New Features:

  • Add MLX GPU utility functions for softmax, L2 norm, entropy, cosine similarity matrix, and MLX↔NumPy conversions.

Enhancements:

  • Compute logit lens and token probability softmax operations on GPU before converting results to NumPy for visualization.
  • Replace the O(n²) Python loop for layer-wise cosine similarity with a single GPU-accelerated matrix multiplication via MLX.
  • Document the experimental NumPy-to-MLX GPU acceleration approach and impacted code paths in a new project context file.
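
The loop-to-matmul rewrite can be illustrated in NumPy (shapes are assumed; the branch performs the same normalize-then-multiply steps on GPU via mx.matmul):

```python
import numpy as np

def similarity_loop(vectors):
    # Old path: O(n^2) pairwise cosine similarity in Python loops
    n = len(vectors)
    sim = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            a, b = vectors[i], vectors[j]
            sim[i, j] = (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return sim

def similarity_matmul(vectors):
    # New path: normalize once, then a single matrix multiply
    normalized = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    return normalized @ normalized.T

layers = np.random.default_rng(1).normal(size=(8, 128))  # 8 layers, 128-dim
assert np.allclose(similarity_loop(layers), similarity_matmul(layers), atol=1e-8)
```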

Documentation:

  • Add CLAUDE.md documenting the experimental MLX GPU acceleration strategy, new utilities, and targeted hotspots.

scouzi1966 and others added 2 commits January 28, 2026 10:37
Add MLX GPU-accelerated utility functions:
- mlx_softmax() - GPU softmax
- mlx_norm() - GPU L2 norm
- mlx_entropy() - GPU entropy calculation
- mlx_cosine_similarity_matrix() - batch matrix multiply

Replace high-impact operations:
- Softmax in Logit Lens (compute on GPU before NumPy conversion)
- Softmax in token probability capture
- Layer similarity matrix: O(n²) loop → single mx.matmul()

These changes keep computation on GPU longer, reducing MLX↔NumPy
transfers. Compatible with both Apple Silicon and mlx-cuda on Linux.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add summary of experimental/numpy-to-mlx-gpu branch to CLAUDE.md
for historical reference.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@sourcery-ai

sourcery-ai bot commented Jan 28, 2026

Reviewer's Guide

Adds MLX-based GPU-accelerated utility functions and rewires key softmax and similarity computations to run on the GPU before converting to NumPy, aiming to reduce CPU-bound work and unnecessary MLX↔NumPy transfers while preserving existing visualizations.

Sequence diagram for GPU-accelerated layer similarity computation

sequenceDiagram
    participant UI as Caller_plot_layer_similarity
    participant PLS as PlotLayerSimilarity
    participant NP as NumPyRuntime
    participant MXU as MLXUtilities
    participant MX as MLXRuntime

    UI->>PLS: plot_layer_similarity(results)
    PLS->>PLS: collect layer_outputs as vectors
    PLS->>NP: stack(vectors)
    NP-->>PLS: vectors_np

    PLS->>MXU: numpy_to_mlx(vectors_np)
    MXU->>MX: array(vectors_np)
    MX-->>MXU: vectors_mx
    MXU-->>PLS: vectors_mx

    PLS->>MXU: mlx_cosine_similarity_matrix(vectors_mx)
    MXU->>MX: mlx_norm(vectors_mx)
    MX-->>MXU: norms
    MXU->>MX: matmul(normalized, normalized_T)
    MX-->>MXU: similarity_mx
    MXU-->>PLS: similarity_mx

    PLS->>MX: eval(similarity_mx)
    MX-->>PLS: realized_similarity_mx

    PLS->>MXU: mlx_to_numpy(realized_similarity_mx)
    MXU->>NP: array(realized_similarity_mx)
    NP-->>MXU: similarity_np
    MXU-->>PLS: similarity_np

    PLS-->>UI: similarity matrix for plotting

Class diagram for new MLX GPU utility functions

classDiagram
    class MLXUtilities {
        +mlx_softmax(x, axis)
        +mlx_norm(x, axis, keepdims)
        +mlx_entropy(probs, axis, base)
        +mlx_cosine_similarity_matrix(vectors)
        +mlx_to_numpy(x)
        +numpy_to_mlx(x)
    }

    class MLXRuntime {
        +softmax(x, axis)
        +sqrt(x)
        +sum(x, axis, keepdims)
        +log(x)
        +log2(x)
        +matmul(a, b)
        +array(x)
        +eval(x)
    }

    class NumPyRuntime {
        +array(x)
        +stack(vectors)
        +argsort(x)
        +linalg_norm(x)
    }

    class PlotLayerSimilarity {
        +plot_layer_similarity(results)
    }

    class LogitLensCapture {
        +_capture_logit_lens(layer_idx, hidden_state)
    }

    class TokenProbsCapture {
        +capture_token_probabilities()
    }

    MLXUtilities --> MLXRuntime : uses
    MLXUtilities --> NumPyRuntime : converts

    LogitLensCapture --> MLXUtilities : uses mlx_softmax
    TokenProbsCapture --> MLXUtilities : uses mlx_softmax
    PlotLayerSimilarity --> MLXUtilities : uses mlx_cosine_similarity_matrix
    PlotLayerSimilarity --> NumPyRuntime : converts similarity to NumPy

File-Level Changes

Change: Introduce MLX GPU utility layer for softmax, norms, entropy, cosine similarity, and array conversions.
Files: mlxlmprobe.py, CLAUDE.md
  • Add mlx_softmax, mlx_norm, mlx_entropy, and mlx_cosine_similarity_matrix helpers wrapping mx operations for GPU execution.
  • Add mlx_to_numpy and numpy_to_mlx conversion helpers to centralize MLX↔NumPy conversions.
  • Document these utilities and their purpose as GPU-accelerated primitives in mlxlmprobe.py.

Change: Refactor logit-lens-related softmax computations to run on GPU prior to NumPy conversion.
Files: mlxlmprobe.py
  • In _capture_logit_lens, compute probabilities via mlx_softmax on mx arrays and then convert logits and probs to NumPy for downstream use, freeing MLX tensors explicitly.
  • In token probability capture, compute softmax on GPU when capture_token_probs is enabled and store NumPy-converted probabilities in results.
  • For full-sequence logit lens, compute per-position softmax using mlx_softmax on GPU, then convert logits and probabilities to NumPy before ranking top predictions.

Change: Replace O(n²) NumPy-based layer similarity loop with a single GPU-accelerated matrix multiplication.
Files: mlxlmprobe.py
  • Collect layer output vectors, stack them, and convert to an mx array for GPU computation.
  • Normalize vectors and compute pairwise cosine similarity via mlx_cosine_similarity_matrix using mx.matmul rather than nested Python loops.
  • Convert the resulting similarity matrix back to NumPy for Plotly visualization.

Change: Add project context and documentation for the experimental NumPy→MLX GPU acceleration effort.
Files: CLAUDE.md
  • Create CLAUDE.md describing the experimental branch, goals, new MLX utilities, and high-impact computation sites.
  • Clarify which operations remain NumPy-based due to data format or visualization requirements, and outline expected benefits and the test plan.


@sourcery-ai sourcery-ai bot left a comment


Hey - I've found 3 issues, and left some high level feedback:

  • The new mlx_entropy helper silently falls back to mx.log2 for any base other than 'e', which could hide misconfigurations — consider validating base and raising for unsupported values.
  • Several places now convert MLX tensors to NumPy with np.array(...) instead of the new mlx_to_numpy helper; using the helper consistently would centralize any future changes to conversion and dtypes.
  • In the token probability and logit-lens paths, you now hold both logits and probabilities in MLX simultaneously before deletion; if GPU memory becomes tight, consider computing probabilities from a sliced or reduced view, or freeing logits before allocating probs_mx.
## Individual Comments

### Comment 1
<location> `mlxlmprobe.py:93-97` </location>
<code_context>
+    return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))
+
+
+def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
+    """GPU-accelerated entropy calculation.
+
+    Args:
+        probs: Probability distribution (should sum to 1 along axis)
+        axis: Axis along which to compute entropy
+        base: "e" for natural log, "2" for log2
+    """
+    log_fn = mx.log if base == "e" else mx.log2
+    return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)
+
</code_context>

<issue_to_address>
**suggestion (bug_risk):** Handle unexpected `base` values explicitly in `mlx_entropy`.

Any `base` value other than "e" currently falls back to `mx.log2`, which can silently produce wrong results if a caller passes an unexpected value (e.g. "E", "10"). Consider validating/normalizing the argument (e.g. only allow {"e", "2"} and raise `ValueError` otherwise) so misconfiguration is caught instead of defaulting to log2.

```suggestion
def mlx_norm(x: mx.array, axis: int = -1, keepdims: bool = False) -> mx.array:
    """GPU-accelerated L2 norm."""
    return mx.sqrt(mx.sum(x * x, axis=axis, keepdims=keepdims))


def mlx_entropy(probs: mx.array, axis: int = -1, base: str = "e") -> mx.array:
    """GPU-accelerated entropy calculation.

    Args:
        probs: Probability distribution (should sum to 1 along axis)
        axis: Axis along which to compute entropy
        base: "e" for natural log, "2" for log2

    Raises:
        ValueError: If `base` is not one of {"e", "2"}.
    """
    if base == "e":
        log_fn = mx.log
    elif base == "2":
        log_fn = mx.log2
    else:
        raise ValueError(f"Unsupported base for mlx_entropy: {base!r}. Expected 'e' or '2'.")
    return -mx.sum(probs * log_fn(probs + 1e-10), axis=axis)
```
</issue_to_address>
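
A NumPy analogue of the suggested validation (illustrative only — the real helper is MLX-based and the function name here is hypothetical):

```python
import numpy as np

def entropy_ref(probs, axis=-1, base="e"):
    # Validate base explicitly instead of silently falling back to log2;
    # the 1e-10 offset guards against log(0), matching the MLX helper
    if base == "e":
        log_fn = np.log
    elif base == "2":
        log_fn = np.log2
    else:
        raise ValueError(f"Unsupported base: {base!r}. Expected 'e' or '2'.")
    return -np.sum(probs * log_fn(probs + 1e-10), axis=axis)
```

With this check, a caller passing base="10" gets a ValueError rather than a silently wrong log2-based result.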

### Comment 2
<location> `mlxlmprobe.py:2433-2439` </location>
<code_context>
-                        max_logit = pos_logits.max()
-                        exp_logits = np.exp(pos_logits - max_logit)
-                        pos_probs = exp_logits / exp_logits.sum()
+                        # Compute softmax on GPU before converting to NumPy
+                        pos_logits_mx = logits[0, pos, :].astype(mx.float32)
+                        pos_probs_mx = mlx_softmax(pos_logits_mx)
+                        mx.eval(pos_logits_mx, pos_probs_mx)
+                        pos_logits = np.array(pos_logits_mx)
+                        pos_probs = np.array(pos_probs_mx)
                         pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
                         pos_preds = []
</code_context>

<issue_to_address>
**suggestion (performance):** Explicitly free per-position MLX tensors inside the loop to reduce peak GPU memory usage.

Within this per-position loop, `pos_logits_mx` and `pos_probs_mx` are reallocated each iteration and only freed when the loop exits. For long sequences and many layers, this can spike peak GPU memory. Similar to the earlier `del logits_pos, probs_mx`, consider adding `del pos_logits_mx, pos_probs_mx` at the end of the loop body to keep GPU memory usage lower during logit lens capture.

```suggestion
                        pos_logits_mx = logits[0, pos, :].astype(mx.float32)
                        pos_probs_mx = mlx_softmax(pos_logits_mx)
                        mx.eval(pos_logits_mx, pos_probs_mx)
                        pos_logits = np.array(pos_logits_mx)
                        pos_probs = np.array(pos_probs_mx)
                        # Free per-position MLX tensors to reduce peak GPU memory usage
                        del pos_logits_mx, pos_probs_mx
                        pos_top_indices = np.argsort(pos_logits)[-5:][::-1]
                        pos_preds = []
```
</issue_to_address>

### Comment 3
<location> `mlxlmprobe.py:1923` </location>
<code_context>
+                logits_pos = layer_logits[0, pos, :].astype(mx.float32)
+
+                # Compute probabilities on GPU before converting to NumPy
+                probs_mx = mlx_softmax(logits_pos)
+                mx.eval(logits_pos, probs_mx)  # Evaluate both together

</code_context>

<issue_to_address>
**issue (complexity):** Consider extracting shared MLX→NumPy conversion and GPU-handling into small helpers so the probing and plotting code stays simpler and more focused on NumPy-level logic.

You can reduce the added complexity by centralizing the MLX→NumPy conversion and separating GPU concerns from the probing logic.

### 1. Factor out a safe MLX→NumPy helper

The `try/except` + `tolist()` pattern is duplicated and mixes evaluation with conversion:

```python
probs_mx = mlx_softmax(logits_pos)
mx.eval(logits_pos, probs_mx)
try:
    logits_np = np.array(logits_pos)
    probs = np.array(probs_mx)
except (RuntimeError, TypeError, ValueError):
    logits_np = np.array(logits_pos.tolist(), dtype=np.float32)
    probs = np.array(probs_mx.tolist(), dtype=np.float32)
```

Introduce a small helper that you can reuse everywhere:

```python
def safe_mlx_to_numpy(x: mx.array, dtype=np.float32) -> np.ndarray:
    mx.eval(x)
    try:
        return np.array(x, dtype=dtype)
    except (RuntimeError, TypeError, ValueError):
        return np.array(x.tolist(), dtype=dtype)
```

Then simplify the call sites:

```python
logits_pos = layer_logits[0, pos, :].astype(mx.float32)
probs_mx = mlx_softmax(logits_pos)

logits_np = safe_mlx_to_numpy(logits_pos)
probs = safe_mlx_to_numpy(probs_mx)

del logits_pos, probs_mx
```

And similarly in `_capture_logit_lens`:

```python
logits_slice = logits[0, -1, :].astype(mx.float32)

if self.config.capture_token_probs:
    probs_mx = mlx_softmax(logits_slice)
    self.results.token_probs = safe_mlx_to_numpy(probs_mx)

self.results.logits = safe_mlx_to_numpy(logits_slice)
```

And in the full-sequence block:

```python
pos_logits_mx = logits[0, pos, :].astype(mx.float32)
pos_probs_mx = mlx_softmax(pos_logits_mx)

pos_logits = safe_mlx_to_numpy(pos_logits_mx)
pos_probs = safe_mlx_to_numpy(pos_probs_mx)
```

This keeps behavior identical but removes repeated `try/except` and evaluation boilerplate.

### 2. Separate “what” from “how” in the probing logic

Right now `_capture_logit_lens` interleaves:

- computing softmax on GPU,
- evaluating MLX tensors,
- converting to NumPy,
- assigning to `self.results`.

You can isolate the GPU handling into a tiny utility and keep probing logic focused on results:

```python
def compute_logits_and_probs(
    logits: mx.array, capture_probs: bool
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    logits_slice = logits[0, -1, :].astype(mx.float32)
    if capture_probs:
        probs_mx = mlx_softmax(logits_slice)
        logits_np = safe_mlx_to_numpy(logits_slice)
        probs_np = safe_mlx_to_numpy(probs_mx)
        return logits_np, probs_np
    else:
        logits_np = safe_mlx_to_numpy(logits_slice)
        return logits_np, None
```

Usage in `_capture_logit_lens`:

```python
self.results.logits, self.results.token_probs = compute_logits_and_probs(
    logits, self.config.capture_token_probs
)
```

The rest of the method can then work purely with NumPy and business logic.

### 3. Simplify `plot_layer_similarity` GPU path

You already have `mlx_to_numpy` and don’t need the extra staging via NumPy stacking:

```python
vectors_mx = mx.stack(
    [mx.array(results.layer_outputs[idx].flatten()) for idx in layer_indices]
)
similarity_mx = mlx_cosine_similarity_matrix(vectors_mx)
similarity = mlx_to_numpy(similarity_mx)
```

This removes the explicit `np.zeros` initialization, the manual `vectors` list plus `np.stack`, and the separate `mx.eval(similarity_mx)` call (since `mlx_to_numpy` can handle evaluation if you fold that into it in the future).
</issue_to_address>


