Add GPU support: val primitives, CUDA and Metal backends#212

Open
PhilippGrulich wants to merge 11 commits into main from claude/plan-gpu-support-2qGrr

Conversation

@PhilippGrulich (Member)

  • Add GPU intrinsic operations (THREAD_IDX, BLOCK_IDX, BLOCK_DIM, GRID_DIM,
    SYNC_THREADS, SHARED_ALLOCA) to tracing Op enum and IR OperationType enum
  • Add GPU IR operation classes (ThreadIdxOperation, BlockIdxOperation, etc.)
  • Extend TracingInterface with traceGPUIndex, traceSyncThreads, traceSharedAlloca
  • Implement GPU tracing in both ExceptionBasedTraceContext and LazyTraceContext
  • Add trace-to-IR conversion for all GPU operations
  • Create nautilus::gpu public API (gpu.hpp) with CPU fallback semantics:
    threadIdx/blockIdx return 0, blockDim/gridDim return 1, syncThreads is no-op
  • Add CUDA backend: generates .cu with __global__ kernels, host wrapper with
    configurable grid/block dims via engine options, unified memory model
  • Add Metal backend: generates .metal MSL with kernel functions using
    [[thread_position_in_threadgroup]] etc. attributes, threadgroup shared memory
  • Register backends as ENABLE_CUDA_BACKEND and ENABLE_METAL_BACKEND cmake options
  • All 143 existing tests pass unchanged

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
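
The CPU fallback semantics in the bullet list above can be sketched as follows (a minimal illustration, not the actual nautilus gpu.hpp; the `gpu_sketch` namespace and function names here are stand-ins):

```cpp
#include <cstdint>

// Sketch of the CPU fallback semantics: on a CPU-only build, the index
// intrinsics collapse to a single-thread, single-block view.
namespace gpu_sketch {
inline uint32_t threadIdx_x() { return 0; } // only one "thread"
inline uint32_t blockIdx_x()  { return 0; } // only one "block"
inline uint32_t blockDim_x()  { return 1; } // a block holds one thread
inline uint32_t gridDim_x()   { return 1; } // a grid holds one block
inline void     syncThreads() {}            // no-op: nothing to synchronize
}

// With these values the standard global-thread-index computation
// evaluates to 0, so a kernel body runs once over element 0.
inline uint32_t globalThreadIndex() {
    return gpu_sketch::blockIdx_x() * gpu_sketch::blockDim_x()
         + gpu_sketch::threadIdx_x();
}
```

This is why the 143 existing tests can pass unchanged: code using the intrinsics degenerates to ordinary single-threaded execution on CPU backends.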

@github-actions bot (Contributor) left a comment:

Tracing Benchmark

Benchmark suite Current: 317192f Previous: c01247b Ratio
comp_mlir_add 8.74752 ms (± 543.735) 8.25938 ms (± 184.393) 1.06
comp_mlir_ifThenElse 9.35877 ms (± 223.295) 8.83948 ms (± 177.957) 1.06
comp_mlir_deeplyNestedIfElse 8.42528 ms (± 550.802) 7.78595 ms (± 212.777) 1.08
comp_mlir_loop 10.6952 ms (± 246.131) 10.0365 ms (± 327.485) 1.07
comp_mlir_ifInsideLoop 33.3368 ms (± 488.962) 32.429 ms (± 343.986) 1.03
comp_mlir_loopDirectCall 15.4543 ms (± 330.561) 14.7347 ms (± 216.972) 1.05
comp_mlir_pointerLoop 32.296 ms (± 442.232) 31.13 ms (± 294.021) 1.04
comp_mlir_staticLoop 8.05071 ms (± 250.054) 7.63437 ms (± 172.018) 1.05
comp_mlir_fibonacci 14.6019 ms (± 327.979) 13.6997 ms (± 354.395) 1.07
comp_mlir_gcd 13.361 ms (± 457.576) 12.1509 ms (± 201.911) 1.10
comp_mlir_nestedIf10 14.1835 ms (± 368.486) 13.1972 ms (± 204.189) 1.07
comp_mlir_nestedIf100 29.3998 ms (± 774.024) 27.7343 ms (± 269.018) 1.06
comp_mlir_chainedIf10 13.9294 ms (± 284.066) 12.5453 ms (± 829.926) 1.11
comp_mlir_chainedIf100 24.7064 ms (± 413.48) 23.6239 ms (± 333.055) 1.05
comp_cpp_add 26.8487 ms (± 550.359) 25.866 ms (± 438.323) 1.04
comp_cpp_ifThenElse 27.3219 ms (± 655.689) 26.0439 ms (± 482.648) 1.05
comp_cpp_deeplyNestedIfElse 28.3365 ms (± 666.985) 27.2588 ms (± 426.676) 1.04
comp_cpp_loop 27.6935 ms (± 519.406) 26.1953 ms (± 282.785) 1.06
comp_cpp_ifInsideLoop 28.4682 ms (± 450.626) 27.2569 ms (± 493.711) 1.04
comp_cpp_loopDirectCall 27.8048 ms (± 487.759) 26.7647 ms (± 373.636) 1.04
comp_cpp_pointerLoop 27.5909 ms (± 520.776) 26.5196 ms (± 309.284) 1.04
comp_cpp_staticLoop 27.0831 ms (± 578.924) 26.3795 ms (± 521.078) 1.03
comp_cpp_fibonacci 27.7238 ms (± 460.827) 26.4091 ms (± 341.807) 1.05
comp_cpp_gcd 27.3537 ms (± 547.978) 26.2101 ms (± 367.548) 1.04
comp_cpp_nestedIf10 30.5591 ms (± 388.826) 29.4624 ms (± 394.264) 1.04
comp_cpp_nestedIf100 63.8057 ms (± 438.664) 62.783 ms (± 397.881) 1.02
comp_cpp_chainedIf10 32.8195 ms (± 1.20791) 31.7236 ms (± 618.486) 1.03
comp_cpp_chainedIf100 93.8645 ms (± 571.73) 92.5733 ms (± 532.692) 1.01
comp_bc_add 13.8126 us (± 2.42848) 14.2352 us (± 2.37181) 0.97
comp_bc_ifThenElse 16.8132 us (± 3.36083) 17.369 us (± 2.88438) 0.97
comp_bc_deeplyNestedIfElse 22.1246 us (± 3.47073) 22.1564 us (± 3.70898) 1.00
comp_bc_loop 16.7706 us (± 3.142) 17.8152 us (± 3.42614) 0.94
comp_bc_ifInsideLoop 20.6198 us (± 4.55374) 20.6383 us (± 4.07876) 1.00
comp_bc_loopDirectCall 18.309 us (± 3.42358) 18.9324 us (± 3.463) 0.97
comp_bc_pointerLoop 20.0771 us (± 4.33144) 19.8306 us (± 4.67596) 1.01
comp_bc_staticLoop 15.7806 us (± 2.56726) 17.2518 us (± 4.40826) 0.91
comp_bc_fibonacci 18.2577 us (± 3.88497) 18.3354 us (± 4.51067) 1.00
comp_bc_gcd 16.5204 us (± 2.3263) 17.6809 us (± 3.31926) 0.93
comp_bc_nestedIf10 36.0251 us (± 4.95856) 34.7376 us (± 4.99439) 1.04
comp_bc_nestedIf100 183.179 us (± 12.9148) 175.939 us (± 11.7728) 1.04
comp_bc_chainedIf10 50.5472 us (± 6.93489) 48.2046 us (± 7.65958) 1.05
comp_bc_chainedIf100 288.284 us (± 19.7275) 279.616 us (± 16.7199) 1.03
comp_asmjit_add 21.2227 us (± 4.08603) 21.3334 us (± 5.82024) 0.99
comp_asmjit_ifThenElse 34.8803 us (± 5.60113) 32.8085 us (± 4.55905) 1.06
comp_asmjit_deeplyNestedIfElse 60.4042 us (± 10.224) 58.1518 us (± 10.5121) 1.04
comp_asmjit_loop 36.7028 us (± 5.65698) 35.476 us (± 5.10508) 1.03
comp_asmjit_ifInsideLoop 60.6574 us (± 10.6946) 57.8955 us (± 9.42696) 1.05
comp_asmjit_loopDirectCall 48.8969 us (± 10.3338) 45.666 us (± 8.06252) 1.07
comp_asmjit_pointerLoop 51.5764 us (± 11.5863) 48.0129 us (± 8.49623) 1.07
comp_asmjit_staticLoop 29.2017 us (± 5.71962) 27.6994 us (± 4.60402) 1.05
comp_asmjit_fibonacci 64.4069 us (± 14.8416) 43.9836 us (± 8.35777) 1.46
comp_asmjit_gcd 36.1944 us (± 5.91345) 35.3579 us (± 6.11817) 1.02
comp_asmjit_nestedIf10 114.627 us (± 15.5419) 109.6 us (± 13.0658) 1.05
comp_asmjit_nestedIf100 1.16029 ms (± 27.3463) 1.12905 ms (± 22.2548) 1.03
comp_asmjit_chainedIf10 168.151 us (± 16.7987) 164.837 us (± 19.0884) 1.02
comp_asmjit_chainedIf100 2.33159 ms (± 42.2964) 2.27325 ms (± 47.0784) 1.03
trace_add 2.52911 us (± 271.632) 2.52067 us (± 245.618) 1.00
completing_trace_add 2.46474 us (± 317.69) 2.6538 us (± 358.325) 0.93
trace_ifThenElse 11.8689 us (± 1.90989) 11.8139 us (± 1.83892) 1.00
completing_trace_ifThenElse 5.39816 us (± 659.542) 5.51813 us (± 747.257) 0.98
trace_deeplyNestedIfElse 35.399 us (± 6.87071) 37.1472 us (± 6.76066) 0.95
completing_trace_deeplyNestedIfElse 15.5282 us (± 2.78343) 17.0804 us (± 3.53069) 0.91
trace_loop 11.415 us (± 1.64561) 11.7774 us (± 1.9669) 0.97
completing_trace_loop 5.24933 us (± 586.436) 5.66643 us (± 1.08795) 0.93
trace_ifInsideLoop 22.903 us (± 2.98762) 22.9594 us (± 3.28725) 1.00
completing_trace_ifInsideLoop 10.9452 us (± 2.17147) 10.6377 us (± 1.71834) 1.03
trace_loopDirectCall 11.4859 us (± 1.87757) 11.861 us (± 1.96453) 0.97
completing_trace_loopDirectCall 5.45584 us (± 747.752) 5.53426 us (± 799.207) 0.99
trace_pointerLoop 17.3933 us (± 3.43152) 18.179 us (± 3.86525) 0.96
completing_trace_pointerLoop 11.5549 us (± 1.71019) 11.9649 us (± 2.0581) 0.97
trace_staticLoop 9.68877 us (± 1.4669) 9.83869 us (± 1.37239) 0.98
completing_trace_staticLoop 9.2314 us (± 1.30452) 9.32965 us (± 1.35495) 0.99
trace_fibonacci 12.9496 us (± 2.21381) 13.0407 us (± 1.88404) 0.99
completing_trace_fibonacci 6.766 us (± 870.805) 7.10904 us (± 961.419) 0.95
trace_gcd 11.1131 us (± 2.0464) 10.7828 us (± 1.7374) 1.03
completing_trace_gcd 4.51661 us (± 543.786) 4.64365 us (± 522.599) 0.97
trace_nestedIf10 56.1683 us (± 7.75532) 57.6997 us (± 7.73545) 0.97
completing_trace_nestedIf10 55.6981 us (± 8.9935) 58.0938 us (± 8.82738) 0.96
trace_nestedIf100 1.77001 ms (± 38.234) 1.80438 ms (± 39.6808) 0.98
completing_trace_nestedIf100 1.75542 ms (± 54.7853) 1.81009 ms (± 37.214) 0.97
trace_chainedIf10 138.339 us (± 11.8288) 140.21 us (± 12.5902) 0.99
completing_trace_chainedIf10 76.2653 us (± 16.851) 71.212 us (± 9.26146) 1.07
trace_chainedIf100 5.18013 ms (± 79.54) 5.15008 ms (± 37.6043) 1.01
completing_trace_chainedIf100 2.78133 ms (± 35.3949) 2.86089 ms (± 51.0819) 0.97
ssa_add 192.054 ns (± 11.6188) 206.987 ns (± 29.7624) 0.93
ssa_ifThenElse 469.173 ns (± 32.8491) 507.968 ns (± 67.906) 0.92
ssa_deeplyNestedIfElse 1.16322 us (± 96.6976) 1.2589 us (± 139.2) 0.92
ssa_loop 495.645 ns (± 32.3777) 512.462 ns (± 42.9444) 0.97
ssa_ifInsideLoop 946.469 ns (± 140.644) 978.605 ns (± 90.4827) 0.97
ssa_loopDirectCall 502.414 ns (± 42.2394) 528.28 ns (± 61.6351) 0.95
ssa_pointerLoop 614.685 ns (± 82.3495) 636.582 ns (± 92.2697) 0.97
ssa_staticLoop 517.374 ns (± 43.2231) 498.308 ns (± 51.1977) 1.04
ssa_fibonacci 513.49 ns (± 30.9228) 537.388 ns (± 52.4788) 0.96
ssa_gcd 472.464 ns (± 44.6959) 490.895 ns (± 82.9043) 0.96
e2e_tiered_bc_to_mlir 43.5856 us (± 16.7052) 37.5093 us (± 11.7494) 1.16
e2e_single_mlir 8.44759 ms (± 195.58) 8.19571 ms (± 167.513) 1.03
exec_mlir_add 9.84863 ns (± 0.827914) 10.008 ns (± 1.2622) 0.98
exec_mlir_fibonacci 14.5198 us (± 1.58884) 14.7086 us (± 2.13282) 0.99
exec_mlir_sum 621.367 us (± 28.6584) 561.635 us (± 24.7226) 1.11
exec_cpp_add 4.68436 ns (± 0.633448) 4.70719 ns (± 0.588182) 1.00
exec_cpp_fibonacci 96.6179 us (± 7.66004) 94.611 us (± 3.80368) 1.02
exec_cpp_sum 35.9607 ms (± 84.1939) 36.1303 ms (± 533.696) 1.00
exec_bc_add 43.7688 ns (± 6.27891) 55.5753 ns (± 15.5573) 0.79
exec_bc_fibonacci 929.169 us (± 11.4064) 932.445 us (± 28.0037) 1.00
exec_bc_sum 199.704 ms (± 190.562) 196.151 ms (± 1.34432) 1.02
exec_asmjit_add 3.25985 ns (± 0.454466) 3.23032 ns (± 0.307965) 1.01
exec_asmjit_fibonacci 21.3528 us (± 1.43005) 21.5547 us (± 2.23995) 0.99
exec_asmjit_sum 4.61414 ms (± 56.4904) 4.61075 ms (± 40.9482) 1.00
tiered_compile_addOne 42.2736 us (± 10.6798) 42.5661 us (± 12.2529) 0.99
single_compile_mlir_addOne 6.40473 ms (± 162.561) 6.56363 ms (± 204.754) 0.98
single_compile_cpp_addOne 26.3563 ms (± 620.501) 25.3529 ms (± 376.871) 1.04
single_compile_bc_addOne 42.6625 us (± 12.8758) 42.2573 us (± 11.3992) 1.01
tiered_compile_sumLoop 61.9982 us (± 14.6537) 61.9797 us (± 12.8559) 1.00
single_compile_mlir_sumLoop 8.73637 ms (± 252.431) 8.3116 ms (± 164.895) 1.05
single_compile_cpp_sumLoop 27.5291 ms (± 601.029) 26.3052 ms (± 333.381) 1.05
single_compile_bc_sumLoop 62.3924 us (± 15.4477) 60.6621 us (± 13.2576) 1.03
ir_add 840.871 ns (± 74.5984) 915.956 ns (± 119.273) 0.92
ir_ifThenElse 2.58201 us (± 255.862) 2.5734 us (± 292.207) 1.00
ir_deeplyNestedIfElse 6.90664 us (± 895.681) 6.87058 us (± 634.436) 1.01
ir_loop 3.05842 us (± 330.475) 3.06223 us (± 361.375) 1.00
ir_ifInsideLoop 5.812 us (± 365.581) 5.91892 us (± 738.714) 0.98
ir_loopDirectCall 3.28239 us (± 278.58) 3.30114 us (± 364.361) 0.99
ir_pointerLoop 4.03816 us (± 350.207) 3.90241 us (± 328.825) 1.03
ir_staticLoop 2.26288 us (± 207.731) 2.30728 us (± 272.394) 0.98
ir_fibonacci 3.25276 us (± 266.194) 3.25706 us (± 357.786) 1.00
ir_gcd 2.72281 us (± 200.911) 2.75793 us (± 297.114) 0.99
ir_nestedIf10 15.406 us (± 989.135) 16.4979 us (± 1.85667) 0.93
ir_nestedIf100 188.837 us (± 8.78571) 196.116 us (± 16.1689) 0.96
ir_chainedIf10 28.9677 us (± 2.65582) 29.2126 us (± 2.07206) 0.99
ir_chainedIf100 356.403 us (± 15.1721) 372.703 us (± 10.9432) 0.96
exec_bc_addOne 35.352 ns (± 4.73231) 38.3849 ns (± 9.06161) 0.92
exec_mlir_addOne 290.147 ns (± 9.09905) 282.268 ns (± 7.59911) 1.03
exec_cpp_addOne 4.06012 ns (± 0.53483) 4.05985 ns (± 0.739266) 1.00
exec_interpreted_addOne 39.0327 ns (± 2.19069) 39.3263 ns (± 2.05037) 0.99

This comment was automatically generated by a workflow using github-action-benchmark.

@PhilippGrulich force-pushed the claude/plan-gpu-support-2qGrr branch 2 times, most recently from 416958b to 9efbbdd on April 3, 2026 at 07:35
claude added 10 commits on April 5, 2026 at 22:06
- Add GPU intrinsic operations (THREAD_IDX, BLOCK_IDX, BLOCK_DIM, GRID_DIM,
  SYNC_THREADS, SHARED_ALLOCA) to tracing Op enum and IR OperationType enum
- Add GPU IR operation classes (ThreadIdxOperation, BlockIdxOperation, etc.)
- Extend TracingInterface with traceGPUIndex, traceSyncThreads, traceSharedAlloca
- Implement GPU tracing in both ExceptionBasedTraceContext and LazyTraceContext
- Add trace-to-IR conversion for all GPU operations
- Create nautilus::gpu public API (gpu.hpp) with CPU fallback semantics:
  threadIdx/blockIdx return 0, blockDim/gridDim return 1, syncThreads is no-op
- Add CUDA backend: generates .cu with __global__ kernels, host wrapper with
  configurable grid/block dims via engine options, unified memory model
- Add Metal backend: generates .metal MSL with kernel functions using
  [[thread_position_in_threadgroup]] etc. attributes, threadgroup shared memory
- Register backends as ENABLE_CUDA_BACKEND and ENABLE_METAL_BACKEND cmake options
- All 143 existing tests pass unchanged

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Replace custom GPU tracing ops and IR operations with the invoke()
pattern used by existing intrinsics (math, bit, memory). GPU functions
are now plain extern "C" functions that:
- Provide CPU fallback values (threadIdx=0, blockDim=1, etc.)
- Have stable function pointers for intrinsic matching
- Flow through existing CALL → ProxyCallOperation pipeline

GPU backends recognize ProxyCallOperations by function pointer and
replace them with target-specific intrinsics:
- CUDA: threadIdx.x, blockIdx.x, __syncthreads(), etc.
- Metal: nautilus_threadIdx.x via [[thread_position_in_threadgroup]], etc.

No changes to tracing, IR, or TracingInterface needed.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
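
The intrinsic-matching mechanism described above can be sketched like this (illustrative only; the function and lookup names are not the actual nautilus API). The intrinsic is a plain extern "C" function returning the CPU fallback value, and because its address is stable, a backend can recognize the resulting proxy call and substitute the target-specific spelling:

```cpp
#include <cstdint>
#include <cstring>

// CPU fallback with a stable address that backends can match against.
extern "C" uint32_t nautilus_gpu_thread_idx_x() { return 0; }

// Backend-side lookup keyed by function pointer, mimicking how a
// ProxyCallOperation could be mapped to a target-specific intrinsic.
const char* cudaIntrinsicFor(void* fnPtr) {
    if (fnPtr == reinterpret_cast<void*>(&nautilus_gpu_thread_idx_x))
        return "threadIdx.x";
    return nullptr; // not a GPU intrinsic: emit a normal call
}
```

The appeal of this design is that tracing and IR stay untouched: the GPU functions ride through the existing CALL → ProxyCallOperation pipeline like any other intrinsic.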
- GPUFunctions.hpp: test functions using gpu::threadIdx_x/blockIdx_x/etc.
  including global thread index computation, vector add, and sync patterns
- GPUExecutionTest.cpp: interpreter + compiler backend tests verifying
  CPU fallback semantics (threadIdx=0, blockDim=1, etc.)
- TracingTest.cpp: GPU tracing tests comparing against reference traces
  in test/data/gpu-tests/{tracing,after_ssa,ir}/
- GPUCodegenTest.cpp: generates CUDA and Metal source via the lowering
  providers and compares against reference files in test/data/gpu-tests/
  {cuda,metal}/ - verifies intrinsic replacement produces correct
  threadIdx.x, __syncthreads(), [[thread_position_in_threadgroup]], etc.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Design: Kernels are defined as NautilusFunction instances. The host
function calls gpu::launch(kernel, config, args...) which on CPU simply
invokes the kernel once (single-thread fallback). GPU backends detect
which NautilusFunction bodies use GPU intrinsics (threadIdx, etc.) by
scanning their IR for matching ProxyCallOperations, then:

- CUDA: emits __global__ for kernels, <<<grid,block>>> at call sites
- Metal: emits kernel void with [[thread_position_in_threadgroup]] etc.

Host functions that call kernels become extern "C" (CUDA) or plain
functions (Metal). Non-kernel internal functions become __device__.

Example usage:
  static auto myKernel = NautilusFunction{"vecAdd",
      [](val<float*> a, val<float*> b, val<float*> c) {
          auto tid = gpu::threadIdx_x();
          c[tid] = a[tid] + b[tid];
      }};
  // In traced function:
  gpu::launch(myKernel, config, a, b, c);

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
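
The single-thread CPU fallback for the launch path described above can be sketched as follows (the function name and signature are illustrative, not the actual gpu::launch): with threadIdx fixed at 0 and blockDim at 1, invoking the kernel body exactly once covers the whole "grid".

```cpp
#include <utility>

// Sketch of the CPU fallback: the launch configuration is ignored and
// the kernel body executes once, as a plain function call.
template <typename Kernel, typename... Args>
void cpuFallbackLaunch(Kernel&& kernel, Args&&... args) {
    kernel(std::forward<Args>(args)...);
}
```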
Bug 1 (Critical): Fix AddOp using processBinary<ir::AndOperation> instead
of processBinary<ir::AddOperation> in CPP, CUDA, and Metal backends.
The wrong static_cast was undefined behavior (happened to work due to
identical BinaryOperation memory layout).
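
The shape of this bug can be sketched as follows (type and function names are illustrative stand-ins, not the actual backend code): the template parameter of the binary-lowering helper selects both the cast target and the emitted operator, so passing the wrong operation class silently changes the generated expression.

```cpp
#include <string>

// Each IR operation class carries its operator spelling.
struct AddOperationSketch { static constexpr const char* symbol = "+"; };
struct AndOperationSketch { static constexpr const char* symbol = "&"; };

// Templated binary lowering: AddOp must instantiate this with
// AddOperationSketch; using AndOperationSketch emits '&' instead of '+'.
template <typename OpT>
std::string processBinarySketch(const std::string& lhs, const std::string& rhs) {
    return lhs + " " + OpT::symbol + " " + rhs;
}
```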

Bug 2 (High): Fix CUDA name collision where root kernel and host wrapper
were both named "execute". Kernel now gets "_kernel" suffix:
  __global__ void execute_kernel(...) { ... }
  extern "C" void execute(...) { execute_kernel<<<grid,block>>>(...); }

Bug 3 (Design): Remove LaunchConfig struct from gpu::launch() API since
it was silently ignored. Grid/block dims are configured via
engine::Options ("gpu.gridDimX", "gpu.blockDimX", etc.).
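
The string-keyed configuration lookup can be sketched like this (the option keys come from the text above; the Options type here is a stand-in, not nautilus::engine::Options):

```cpp
#include <cstdint>
#include <map>
#include <string>

// Stand-in for an options container with string keys and defaults.
struct OptionsSketch {
    std::map<std::string, uint64_t> values;
    uint64_t getOrDefault(const std::string& key, uint64_t def) const {
        auto it = values.find(key);
        return it == values.end() ? def : it->second;
    }
};

// The backend reads these when emitting the launch configuration.
uint64_t gridDimX(const OptionsSketch& o)  { return o.getOrDefault("gpu.gridDimX", 1); }
uint64_t blockDimX(const OptionsSketch& o) { return o.getOrDefault("gpu.blockDimX", 1); }
```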

Bug 4 (High): Fix Metal backend emitting invalid MSL where a non-kernel
function called a kernel function. MSL kernel functions are entry points
only. Now non-kernel root functions (host code) are skipped in MSL
output - only kernel functions are emitted in the .metal file.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Fix missing space before '=' in negate and not operations across CPP,
CUDA, and Metal backends (generated 'var= ~x' instead of 'var = ~x').

Fix GPU codegen test flakiness by finding the most recently written
dump file instead of deleting the shared /tmp/dump directory which
races with other parallel tests.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Move all GPU code (API, CUDA/Metal backends, tests) from the core
nautilus library into a separate plugin module at plugins/gpu/.

Plugin architecture:
- plugins/gpu/ builds as its own library (nautilus-gpu) linked to nautilus
- GPU backends register dynamically via CompilationBackendRegistry::registerBackend()
  using a static initializer (GPUBackendRegistration.cpp)
- Plugin has its own test executables and test data
- Enabled via -DENABLE_GPU_PLUGIN=ON (off by default)
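
The static-initializer registration pattern described above can be sketched as follows (registry, backend, and method names here are illustrative stand-ins for the actual CompilationBackendRegistry API):

```cpp
#include <functional>
#include <map>
#include <string>

struct Backend { std::string name; };

// Singleton registry; getInstance() returns a non-const pointer so
// plugins can register into it.
struct RegistrySketch {
    static RegistrySketch* getInstance() {
        static RegistrySketch instance;
        return &instance;
    }
    void registerBackend(const std::string& name, std::function<Backend()> factory) {
        factories[name] = std::move(factory);
    }
    std::map<std::string, std::function<Backend()>> factories;
};

// In the plugin, a namespace-scope object performs registration at load
// time. This is exactly the object that a whole-archive link flag must
// keep alive, since nothing references it by name.
namespace {
struct RegisterGPUBackends {
    RegisterGPUBackends() {
        RegistrySketch::getInstance()->registerBackend(
            "cuda", [] { return Backend{"cuda"}; });
    }
} registerGPUBackends;
}
```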

Core library changes:
- CompilationBackendRegistry gains public registerBackend() method
- getInstance() returns non-const pointer for plugin registration
- All GPU #ifdefs, includes, and registrations removed from core
- Core library tests pass unchanged (143/143)

Plugin structure:
  plugins/gpu/
  ├── include/nautilus/       (gpu.hpp, gpu_intrinsic_targets.hpp)
  ├── src/                    (gpu.cpp, GPUBackendRegistration.cpp)
  │   ├── cuda/               (CUDA backend)
  │   └── metal/              (Metal backend)
  └── test/                   (5 test executables + reference data)

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Fix 1: Extract shared lowering code into GPULoweringProviderBase (CRTP).
Common operation dispatch, block processing, and arithmetic/logic/memory
handlers are shared. CUDA provider: 685→310 lines. Metal: 536→201 lines.
Base class: 473 lines. Net 30% reduction in total code.
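
The CRTP sharing in Fix 1 can be sketched like this (class and method names are illustrative, not the actual provider classes): the base owns the common dispatch and defers only the target-specific spelling to the derived provider, avoiding virtual-call overhead.

```cpp
#include <string>

// Shared lowering logic lives in the CRTP base.
template <typename Derived>
struct LoweringBaseSketch {
    std::string lowerSyncThreads() {
        // Common dispatch; the derived class supplies the intrinsic text.
        return static_cast<Derived*>(this)->syncThreadsIntrinsic() + ";";
    }
};

struct CudaProviderSketch : LoweringBaseSketch<CudaProviderSketch> {
    std::string syncThreadsIntrinsic() { return "__syncthreads()"; }
};

struct MetalProviderSketch : LoweringBaseSketch<MetalProviderSketch> {
    std::string syncThreadsIntrinsic() {
        return "threadgroup_barrier(mem_flags::mem_threadgroup)";
    }
};
```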

Fix 2: Delete orphaned nautilus/test/data/gpu-tests/ (moved to plugin).

Fix 3: Add gpu_backends.hpp with lowerToCUDA()/lowerToMetal() public API.
GPUCodegenTest now calls lowering providers directly instead of fragile
dump-file-scanning via JITCompiler. Eliminates parallel test flakiness.

Fix 4: Add whole-archive linker flag to prevent dead-stripping of the
static GPUBackendRegistration initializer.
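
One way to express such a whole-archive link in modern CMake is sketched below (the consumer target name is assumed; the project may use a different mechanism):

```cmake
# Illustrative sketch: force the entire static archive to be linked so the
# unreferenced static registration object in GPUBackendRegistration.cpp is
# not dead-stripped. Requires CMake >= 3.24 for the LINK_LIBRARY genex.
target_link_libraries(my_app PRIVATE
    $<LINK_LIBRARY:WHOLE_ARCHIVE,nautilus-gpu>)
```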

Fix 5: Add install() target for nautilus-gpu library and headers.

Fix 6: Improve gpu::launch() documentation explaining the full mechanism,
how grid/block dimensions are configured via engine::Options, and the
CPU fallback semantics.

Fix 7: format.sh already covers plugins/ via git ls-files. No change needed.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
gpu::launch() now takes GridDim and BlockDim as traced val<uint32_t>
values, enabling dynamic launch configuration computed from data:

  auto blocks = (n + val<uint32_t>(255)) / val<uint32_t>(256);
  gpu::launch(kernel, gpu::GridDim{blocks}, gpu::BlockDim{256}, args...);

Implementation:
- gpu::launch() calls setGrid(x,y,z) and setBlock(x,y,z) before the
  kernel call. These are extern "C" functions (no-op on CPU) that trace
  as ProxyCallOperations via invoke().
- GPU backends register setGrid/setBlock as intrinsics that capture the
  argument variable names into pending launch config state.
- When the backend encounters the next kernel call, it uses the captured
  variable names in the launch syntax instead of static Options values.
- Falls back to engine::Options if no setGrid/setBlock precedes the call.

CUDA generates: vecAdd<<<dim3(var_$8,1,1),dim3(256,1,1)>>>(args)
where var_$8 = (n + 255) / 256 is computed at runtime.
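
The host-side pieces of this mechanism can be sketched as follows (the extern "C" name is illustrative, not the actual symbol): setGrid/setBlock are no-ops on CPU, but their traced proxy calls let a GPU backend capture the argument variables for the launch syntax.

```cpp
#include <cstdint>

// CPU fallback for the launch-config intrinsic: nothing to configure.
// On a GPU backend, the traced proxy call to this function is what
// carries the grid dimensions into the <<<...>>> launch.
extern "C" void nautilus_gpu_set_grid(uint32_t x, uint32_t y, uint32_t z) {
    (void)x; (void)y; (void)z;
}

// The grid-size arithmetic from the example above: ceil-division of
// n elements into 256-thread blocks.
uint32_t blocksFor(uint32_t n) { return (n + 255u) / 256u; }
```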

Also fixes GPU_TEST_DATA_FOLDER macro collision with core TEST_DATA_FOLDER.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
… intrinsics

setGrid/setBlock intrinsics were incorrectly causing host functions that
call gpu::launch() to be classified as kernels. Now only device intrinsics
(threadIdx, blockIdx, syncThreads, etc.) mark a function as a kernel.
Launch config intrinsics are host-side and handled separately.

Fixes Metal output for gpuLaunchVecAddDynamic: the host function that
computes grid dims and launches the kernel is now correctly omitted from
MSL (only the kernel function is emitted).

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
@PhilippGrulich force-pushed the claude/plan-gpu-support-2qGrr branch 2 times, most recently from 6ff4c57 to 03ca011 on April 6, 2026 at 07:45
The Metal backend now produces two outputs:
- .metal: MSL kernel source (device code, as before)
- .cpp: C++ host source with Objective-C Metal API dispatch

The host code uses Metal API to:
1. Create MTLDevice, MTLLibrary, MTLComputePipelineState
2. Set buffer arguments via [encoder setBytes:...]
3. Dispatch with dynamic grid/block dims from setGrid/setBlock
4. Wait for completion via [cmdBuf waitUntilCompleted]

The MetalCompilationBackend compiles the host .cpp via CPPCompiler
(the .metal is compiled separately via xcrun metal on macOS).

API change: lowerToMetal() now returns MetalOutput{deviceCode, hostCode}
instead of a single string. Codegen tests verify both outputs.
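
The two-output return type can be sketched like this (the struct name comes from the text; the field names are assumptions):

```cpp
#include <string>

// Sketch of the two-output API: one compilation produces both the MSL
// device source and the C++ host dispatch source.
struct MetalOutput {
    std::string deviceCode; // MSL kernel source (.metal)
    std::string hostCode;   // C++ host source with Metal API dispatch (.cpp)
};
```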

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
@PhilippGrulich force-pushed the claude/plan-gpu-support-2qGrr branch from 03ca011 to 317192f on April 6, 2026 at 09:20
