Add GPU support: val primitives, CUDA and Metal backends #212
Open
PhilippGrulich wants to merge 11 commits into main from
Conversation
Contributor
Tracing Benchmark
| Benchmark suite | Current: 317192f | Previous: c01247b | Ratio |
|---|---|---|---|
| comp_mlir_add | 8.74752 ms (± 543.735) | 8.25938 ms (± 184.393) | 1.06 |
| comp_mlir_ifThenElse | 9.35877 ms (± 223.295) | 8.83948 ms (± 177.957) | 1.06 |
| comp_mlir_deeplyNestedIfElse | 8.42528 ms (± 550.802) | 7.78595 ms (± 212.777) | 1.08 |
| comp_mlir_loop | 10.6952 ms (± 246.131) | 10.0365 ms (± 327.485) | 1.07 |
| comp_mlir_ifInsideLoop | 33.3368 ms (± 488.962) | 32.429 ms (± 343.986) | 1.03 |
| comp_mlir_loopDirectCall | 15.4543 ms (± 330.561) | 14.7347 ms (± 216.972) | 1.05 |
| comp_mlir_pointerLoop | 32.296 ms (± 442.232) | 31.13 ms (± 294.021) | 1.04 |
| comp_mlir_staticLoop | 8.05071 ms (± 250.054) | 7.63437 ms (± 172.018) | 1.05 |
| comp_mlir_fibonacci | 14.6019 ms (± 327.979) | 13.6997 ms (± 354.395) | 1.07 |
| comp_mlir_gcd | 13.361 ms (± 457.576) | 12.1509 ms (± 201.911) | 1.10 |
| comp_mlir_nestedIf10 | 14.1835 ms (± 368.486) | 13.1972 ms (± 204.189) | 1.07 |
| comp_mlir_nestedIf100 | 29.3998 ms (± 774.024) | 27.7343 ms (± 269.018) | 1.06 |
| comp_mlir_chainedIf10 | 13.9294 ms (± 284.066) | 12.5453 ms (± 829.926) | 1.11 |
| comp_mlir_chainedIf100 | 24.7064 ms (± 413.48) | 23.6239 ms (± 333.055) | 1.05 |
| comp_cpp_add | 26.8487 ms (± 550.359) | 25.866 ms (± 438.323) | 1.04 |
| comp_cpp_ifThenElse | 27.3219 ms (± 655.689) | 26.0439 ms (± 482.648) | 1.05 |
| comp_cpp_deeplyNestedIfElse | 28.3365 ms (± 666.985) | 27.2588 ms (± 426.676) | 1.04 |
| comp_cpp_loop | 27.6935 ms (± 519.406) | 26.1953 ms (± 282.785) | 1.06 |
| comp_cpp_ifInsideLoop | 28.4682 ms (± 450.626) | 27.2569 ms (± 493.711) | 1.04 |
| comp_cpp_loopDirectCall | 27.8048 ms (± 487.759) | 26.7647 ms (± 373.636) | 1.04 |
| comp_cpp_pointerLoop | 27.5909 ms (± 520.776) | 26.5196 ms (± 309.284) | 1.04 |
| comp_cpp_staticLoop | 27.0831 ms (± 578.924) | 26.3795 ms (± 521.078) | 1.03 |
| comp_cpp_fibonacci | 27.7238 ms (± 460.827) | 26.4091 ms (± 341.807) | 1.05 |
| comp_cpp_gcd | 27.3537 ms (± 547.978) | 26.2101 ms (± 367.548) | 1.04 |
| comp_cpp_nestedIf10 | 30.5591 ms (± 388.826) | 29.4624 ms (± 394.264) | 1.04 |
| comp_cpp_nestedIf100 | 63.8057 ms (± 438.664) | 62.783 ms (± 397.881) | 1.02 |
| comp_cpp_chainedIf10 | 32.8195 ms (± 1.20791) | 31.7236 ms (± 618.486) | 1.03 |
| comp_cpp_chainedIf100 | 93.8645 ms (± 571.73) | 92.5733 ms (± 532.692) | 1.01 |
| comp_bc_add | 13.8126 us (± 2.42848) | 14.2352 us (± 2.37181) | 0.97 |
| comp_bc_ifThenElse | 16.8132 us (± 3.36083) | 17.369 us (± 2.88438) | 0.97 |
| comp_bc_deeplyNestedIfElse | 22.1246 us (± 3.47073) | 22.1564 us (± 3.70898) | 1.00 |
| comp_bc_loop | 16.7706 us (± 3.142) | 17.8152 us (± 3.42614) | 0.94 |
| comp_bc_ifInsideLoop | 20.6198 us (± 4.55374) | 20.6383 us (± 4.07876) | 1.00 |
| comp_bc_loopDirectCall | 18.309 us (± 3.42358) | 18.9324 us (± 3.463) | 0.97 |
| comp_bc_pointerLoop | 20.0771 us (± 4.33144) | 19.8306 us (± 4.67596) | 1.01 |
| comp_bc_staticLoop | 15.7806 us (± 2.56726) | 17.2518 us (± 4.40826) | 0.91 |
| comp_bc_fibonacci | 18.2577 us (± 3.88497) | 18.3354 us (± 4.51067) | 1.00 |
| comp_bc_gcd | 16.5204 us (± 2.3263) | 17.6809 us (± 3.31926) | 0.93 |
| comp_bc_nestedIf10 | 36.0251 us (± 4.95856) | 34.7376 us (± 4.99439) | 1.04 |
| comp_bc_nestedIf100 | 183.179 us (± 12.9148) | 175.939 us (± 11.7728) | 1.04 |
| comp_bc_chainedIf10 | 50.5472 us (± 6.93489) | 48.2046 us (± 7.65958) | 1.05 |
| comp_bc_chainedIf100 | 288.284 us (± 19.7275) | 279.616 us (± 16.7199) | 1.03 |
| comp_asmjit_add | 21.2227 us (± 4.08603) | 21.3334 us (± 5.82024) | 0.99 |
| comp_asmjit_ifThenElse | 34.8803 us (± 5.60113) | 32.8085 us (± 4.55905) | 1.06 |
| comp_asmjit_deeplyNestedIfElse | 60.4042 us (± 10.224) | 58.1518 us (± 10.5121) | 1.04 |
| comp_asmjit_loop | 36.7028 us (± 5.65698) | 35.476 us (± 5.10508) | 1.03 |
| comp_asmjit_ifInsideLoop | 60.6574 us (± 10.6946) | 57.8955 us (± 9.42696) | 1.05 |
| comp_asmjit_loopDirectCall | 48.8969 us (± 10.3338) | 45.666 us (± 8.06252) | 1.07 |
| comp_asmjit_pointerLoop | 51.5764 us (± 11.5863) | 48.0129 us (± 8.49623) | 1.07 |
| comp_asmjit_staticLoop | 29.2017 us (± 5.71962) | 27.6994 us (± 4.60402) | 1.05 |
| comp_asmjit_fibonacci | 64.4069 us (± 14.8416) | 43.9836 us (± 8.35777) | 1.46 |
| comp_asmjit_gcd | 36.1944 us (± 5.91345) | 35.3579 us (± 6.11817) | 1.02 |
| comp_asmjit_nestedIf10 | 114.627 us (± 15.5419) | 109.6 us (± 13.0658) | 1.05 |
| comp_asmjit_nestedIf100 | 1.16029 ms (± 27.3463) | 1.12905 ms (± 22.2548) | 1.03 |
| comp_asmjit_chainedIf10 | 168.151 us (± 16.7987) | 164.837 us (± 19.0884) | 1.02 |
| comp_asmjit_chainedIf100 | 2.33159 ms (± 42.2964) | 2.27325 ms (± 47.0784) | 1.03 |
| trace_add | 2.52911 us (± 271.632) | 2.52067 us (± 245.618) | 1.00 |
| completing_trace_add | 2.46474 us (± 317.69) | 2.6538 us (± 358.325) | 0.93 |
| trace_ifThenElse | 11.8689 us (± 1.90989) | 11.8139 us (± 1.83892) | 1.00 |
| completing_trace_ifThenElse | 5.39816 us (± 659.542) | 5.51813 us (± 747.257) | 0.98 |
| trace_deeplyNestedIfElse | 35.399 us (± 6.87071) | 37.1472 us (± 6.76066) | 0.95 |
| completing_trace_deeplyNestedIfElse | 15.5282 us (± 2.78343) | 17.0804 us (± 3.53069) | 0.91 |
| trace_loop | 11.415 us (± 1.64561) | 11.7774 us (± 1.9669) | 0.97 |
| completing_trace_loop | 5.24933 us (± 586.436) | 5.66643 us (± 1.08795) | 0.93 |
| trace_ifInsideLoop | 22.903 us (± 2.98762) | 22.9594 us (± 3.28725) | 1.00 |
| completing_trace_ifInsideLoop | 10.9452 us (± 2.17147) | 10.6377 us (± 1.71834) | 1.03 |
| trace_loopDirectCall | 11.4859 us (± 1.87757) | 11.861 us (± 1.96453) | 0.97 |
| completing_trace_loopDirectCall | 5.45584 us (± 747.752) | 5.53426 us (± 799.207) | 0.99 |
| trace_pointerLoop | 17.3933 us (± 3.43152) | 18.179 us (± 3.86525) | 0.96 |
| completing_trace_pointerLoop | 11.5549 us (± 1.71019) | 11.9649 us (± 2.0581) | 0.97 |
| trace_staticLoop | 9.68877 us (± 1.4669) | 9.83869 us (± 1.37239) | 0.98 |
| completing_trace_staticLoop | 9.2314 us (± 1.30452) | 9.32965 us (± 1.35495) | 0.99 |
| trace_fibonacci | 12.9496 us (± 2.21381) | 13.0407 us (± 1.88404) | 0.99 |
| completing_trace_fibonacci | 6.766 us (± 870.805) | 7.10904 us (± 961.419) | 0.95 |
| trace_gcd | 11.1131 us (± 2.0464) | 10.7828 us (± 1.7374) | 1.03 |
| completing_trace_gcd | 4.51661 us (± 543.786) | 4.64365 us (± 522.599) | 0.97 |
| trace_nestedIf10 | 56.1683 us (± 7.75532) | 57.6997 us (± 7.73545) | 0.97 |
| completing_trace_nestedIf10 | 55.6981 us (± 8.9935) | 58.0938 us (± 8.82738) | 0.96 |
| trace_nestedIf100 | 1.77001 ms (± 38.234) | 1.80438 ms (± 39.6808) | 0.98 |
| completing_trace_nestedIf100 | 1.75542 ms (± 54.7853) | 1.81009 ms (± 37.214) | 0.97 |
| trace_chainedIf10 | 138.339 us (± 11.8288) | 140.21 us (± 12.5902) | 0.99 |
| completing_trace_chainedIf10 | 76.2653 us (± 16.851) | 71.212 us (± 9.26146) | 1.07 |
| trace_chainedIf100 | 5.18013 ms (± 79.54) | 5.15008 ms (± 37.6043) | 1.01 |
| completing_trace_chainedIf100 | 2.78133 ms (± 35.3949) | 2.86089 ms (± 51.0819) | 0.97 |
| ssa_add | 192.054 ns (± 11.6188) | 206.987 ns (± 29.7624) | 0.93 |
| ssa_ifThenElse | 469.173 ns (± 32.8491) | 507.968 ns (± 67.906) | 0.92 |
| ssa_deeplyNestedIfElse | 1.16322 us (± 96.6976) | 1.2589 us (± 139.2) | 0.92 |
| ssa_loop | 495.645 ns (± 32.3777) | 512.462 ns (± 42.9444) | 0.97 |
| ssa_ifInsideLoop | 946.469 ns (± 140.644) | 978.605 ns (± 90.4827) | 0.97 |
| ssa_loopDirectCall | 502.414 ns (± 42.2394) | 528.28 ns (± 61.6351) | 0.95 |
| ssa_pointerLoop | 614.685 ns (± 82.3495) | 636.582 ns (± 92.2697) | 0.97 |
| ssa_staticLoop | 517.374 ns (± 43.2231) | 498.308 ns (± 51.1977) | 1.04 |
| ssa_fibonacci | 513.49 ns (± 30.9228) | 537.388 ns (± 52.4788) | 0.96 |
| ssa_gcd | 472.464 ns (± 44.6959) | 490.895 ns (± 82.9043) | 0.96 |
| e2e_tiered_bc_to_mlir | 43.5856 us (± 16.7052) | 37.5093 us (± 11.7494) | 1.16 |
| e2e_single_mlir | 8.44759 ms (± 195.58) | 8.19571 ms (± 167.513) | 1.03 |
| exec_mlir_add | 9.84863 ns (± 0.827914) | 10.008 ns (± 1.2622) | 0.98 |
| exec_mlir_fibonacci | 14.5198 us (± 1.58884) | 14.7086 us (± 2.13282) | 0.99 |
| exec_mlir_sum | 621.367 us (± 28.6584) | 561.635 us (± 24.7226) | 1.11 |
| exec_cpp_add | 4.68436 ns (± 0.633448) | 4.70719 ns (± 0.588182) | 1.00 |
| exec_cpp_fibonacci | 96.6179 us (± 7.66004) | 94.611 us (± 3.80368) | 1.02 |
| exec_cpp_sum | 35.9607 ms (± 84.1939) | 36.1303 ms (± 533.696) | 1.00 |
| exec_bc_add | 43.7688 ns (± 6.27891) | 55.5753 ns (± 15.5573) | 0.79 |
| exec_bc_fibonacci | 929.169 us (± 11.4064) | 932.445 us (± 28.0037) | 1.00 |
| exec_bc_sum | 199.704 ms (± 190.562) | 196.151 ms (± 1.34432) | 1.02 |
| exec_asmjit_add | 3.25985 ns (± 0.454466) | 3.23032 ns (± 0.307965) | 1.01 |
| exec_asmjit_fibonacci | 21.3528 us (± 1.43005) | 21.5547 us (± 2.23995) | 0.99 |
| exec_asmjit_sum | 4.61414 ms (± 56.4904) | 4.61075 ms (± 40.9482) | 1.00 |
| tiered_compile_addOne | 42.2736 us (± 10.6798) | 42.5661 us (± 12.2529) | 0.99 |
| single_compile_mlir_addOne | 6.40473 ms (± 162.561) | 6.56363 ms (± 204.754) | 0.98 |
| single_compile_cpp_addOne | 26.3563 ms (± 620.501) | 25.3529 ms (± 376.871) | 1.04 |
| single_compile_bc_addOne | 42.6625 us (± 12.8758) | 42.2573 us (± 11.3992) | 1.01 |
| tiered_compile_sumLoop | 61.9982 us (± 14.6537) | 61.9797 us (± 12.8559) | 1.00 |
| single_compile_mlir_sumLoop | 8.73637 ms (± 252.431) | 8.3116 ms (± 164.895) | 1.05 |
| single_compile_cpp_sumLoop | 27.5291 ms (± 601.029) | 26.3052 ms (± 333.381) | 1.05 |
| single_compile_bc_sumLoop | 62.3924 us (± 15.4477) | 60.6621 us (± 13.2576) | 1.03 |
| ir_add | 840.871 ns (± 74.5984) | 915.956 ns (± 119.273) | 0.92 |
| ir_ifThenElse | 2.58201 us (± 255.862) | 2.5734 us (± 292.207) | 1.00 |
| ir_deeplyNestedIfElse | 6.90664 us (± 895.681) | 6.87058 us (± 634.436) | 1.01 |
| ir_loop | 3.05842 us (± 330.475) | 3.06223 us (± 361.375) | 1.00 |
| ir_ifInsideLoop | 5.812 us (± 365.581) | 5.91892 us (± 738.714) | 0.98 |
| ir_loopDirectCall | 3.28239 us (± 278.58) | 3.30114 us (± 364.361) | 0.99 |
| ir_pointerLoop | 4.03816 us (± 350.207) | 3.90241 us (± 328.825) | 1.03 |
| ir_staticLoop | 2.26288 us (± 207.731) | 2.30728 us (± 272.394) | 0.98 |
| ir_fibonacci | 3.25276 us (± 266.194) | 3.25706 us (± 357.786) | 1.00 |
| ir_gcd | 2.72281 us (± 200.911) | 2.75793 us (± 297.114) | 0.99 |
| ir_nestedIf10 | 15.406 us (± 989.135) | 16.4979 us (± 1.85667) | 0.93 |
| ir_nestedIf100 | 188.837 us (± 8.78571) | 196.116 us (± 16.1689) | 0.96 |
| ir_chainedIf10 | 28.9677 us (± 2.65582) | 29.2126 us (± 2.07206) | 0.99 |
| ir_chainedIf100 | 356.403 us (± 15.1721) | 372.703 us (± 10.9432) | 0.96 |
| exec_bc_addOne | 35.352 ns (± 4.73231) | 38.3849 ns (± 9.06161) | 0.92 |
| exec_mlir_addOne | 290.147 ns (± 9.09905) | 282.268 ns (± 7.59911) | 1.03 |
| exec_cpp_addOne | 4.06012 ns (± 0.53483) | 4.05985 ns (± 0.739266) | 1.00 |
| exec_interpreted_addOne | 39.0327 ns (± 2.19069) | 39.3263 ns (± 2.05037) | 0.99 |
This comment was automatically generated by workflow using github-action-benchmark.
Force-pushed from 416958b to 9efbbdd
- Add GPU intrinsic operations (THREAD_IDX, BLOCK_IDX, BLOCK_DIM, GRID_DIM, SYNC_THREADS, SHARED_ALLOCA) to the tracing Op enum and the IR OperationType enum
- Add GPU IR operation classes (ThreadIdxOperation, BlockIdxOperation, etc.)
- Extend TracingInterface with traceGPUIndex, traceSyncThreads, and traceSharedAlloca
- Implement GPU tracing in both ExceptionBasedTraceContext and LazyTraceContext
- Add trace-to-IR conversion for all GPU operations
- Create the nautilus::gpu public API (gpu.hpp) with CPU fallback semantics: threadIdx/blockIdx return 0, blockDim/gridDim return 1, and syncThreads is a no-op
- Add a CUDA backend: generates .cu source with __global__ kernels and a host wrapper, with grid/block dims configurable via engine options; unified memory model
- Add a Metal backend: generates .metal MSL with kernel functions using [[thread_position_in_threadgroup]] and related attributes; threadgroup shared memory
- Register the backends behind the ENABLE_CUDA_BACKEND and ENABLE_METAL_BACKEND CMake options
- All 143 existing tests pass unchanged

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Replace the custom GPU tracing ops and IR operations with the invoke() pattern used by the existing intrinsics (math, bit, memory). GPU functions are now plain extern "C" functions that:
- Provide CPU fallback values (threadIdx=0, blockDim=1, etc.)
- Have stable function pointers for intrinsic matching
- Flow through the existing CALL → ProxyCallOperation pipeline

GPU backends recognize ProxyCallOperations by function pointer and replace them with target-specific intrinsics:
- CUDA: threadIdx.x, blockIdx.x, __syncthreads(), etc.
- Metal: nautilus_threadIdx.x via [[thread_position_in_threadgroup]], etc.

No changes to tracing, IR, or TracingInterface are needed.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
- GPUFunctions.hpp: test functions using gpu::threadIdx_x/blockIdx_x/etc., including global thread index computation, vector add, and sync patterns
- GPUExecutionTest.cpp: interpreter + compiler backend tests verifying CPU fallback semantics (threadIdx=0, blockDim=1, etc.)
- TracingTest.cpp: GPU tracing tests comparing against reference traces in test/data/gpu-tests/{tracing,after_ssa,ir}/
- GPUCodegenTest.cpp: generates CUDA and Metal source via the lowering providers and compares against reference files in test/data/gpu-tests/{cuda,metal}/; verifies that intrinsic replacement produces the correct threadIdx.x, __syncthreads(), [[thread_position_in_threadgroup]], etc.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Design: Kernels are defined as NautilusFunction instances. The host
function calls gpu::launch(kernel, config, args...) which on CPU simply
invokes the kernel once (single-thread fallback). GPU backends detect
which NautilusFunction bodies use GPU intrinsics (threadIdx, etc.) by
scanning their IR for matching ProxyCallOperations, then:
- CUDA: emits __global__ for kernels, <<<grid,block>>> at call sites
- Metal: emits kernel void with [[thread_position_in_threadgroup]] etc.
Host functions that call kernels become extern "C" (CUDA) or plain
functions (Metal). Non-kernel internal functions become __device__.
Example usage:

```cpp
static auto myKernel = NautilusFunction{"vecAdd",
	[](val<float*> a, val<float*> b, val<float*> c) {
		auto tid = gpu::threadIdx_x();
		c[tid] = a[tid] + b[tid];
	}};

// In a traced function:
gpu::launch(myKernel, config, a, b, c);
```

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Bug 1 (Critical): Fix AddOp using processBinary<ir::AndOperation> instead of processBinary<ir::AddOperation> in the CPP, CUDA, and Metal backends. The wrong static_cast was undefined behavior (it happened to work only because the BinaryOperation memory layouts are identical).
Bug 2 (High): Fix a CUDA name collision where the root kernel and the host wrapper were both named "execute". The kernel now gets a "_kernel" suffix:

```cuda
__global__ void execute_kernel(...) { ... }
extern "C" void execute(...) { execute_kernel<<<grid,block>>>(...); }
```
Bug 3 (Design): Remove LaunchConfig struct from gpu::launch() API since
it was silently ignored. Grid/block dims are configured via
engine::Options ("gpu.gridDimX", "gpu.blockDimX", etc.).
Bug 4 (High): Fix the Metal backend emitting invalid MSL in which a non-kernel function called a kernel function. MSL kernel functions are entry points only. Non-kernel root functions (host code) are now skipped in the MSL output; only kernel functions are emitted in the .metal file.
https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Fix a missing space before '=' in the negate and not operations across the CPP, CUDA, and Metal backends (they generated 'var= ~x' instead of 'var = ~x').

Fix GPU codegen test flakiness by finding the most recently written dump file instead of deleting the shared /tmp/dump directory, which races with other parallel tests.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Move all GPU code (API, CUDA/Metal backends, tests) from the core nautilus library into a separate plugin module at plugins/gpu/.

Plugin architecture:
- plugins/gpu/ builds as its own library (nautilus-gpu) linked against nautilus
- GPU backends register dynamically via CompilationBackendRegistry::registerBackend() using a static initializer (GPUBackendRegistration.cpp)
- The plugin has its own test executables and test data
- Enabled via -DENABLE_GPU_PLUGIN=ON (off by default)

Core library changes:
- CompilationBackendRegistry gains a public registerBackend() method
- getInstance() returns a non-const pointer for plugin registration
- All GPU #ifdefs, includes, and registrations are removed from core
- Core library tests pass unchanged (143/143)

Plugin structure:
plugins/gpu/
├── include/nautilus/ (gpu.hpp, gpu_intrinsic_targets.hpp)
├── src/ (gpu.cpp, GPUBackendRegistration.cpp)
│   ├── cuda/ (CUDA backend)
│   └── metal/ (Metal backend)
└── test/ (5 test executables + reference data)

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Fix 1: Extract shared lowering code into GPULoweringProviderBase (CRTP). Common operation dispatch, block processing, and arithmetic/logic/memory handlers are shared. CUDA provider: 685→310 lines. Metal: 536→201 lines. Base class: 473 lines. Net 30% reduction in total code.
Fix 2: Delete the orphaned nautilus/test/data/gpu-tests/ (moved to the plugin).
Fix 3: Add gpu_backends.hpp with a lowerToCUDA()/lowerToMetal() public API. GPUCodegenTest now calls the lowering providers directly instead of the fragile dump-file scanning via JITCompiler, eliminating the parallel-test flakiness.
Fix 4: Add a whole-archive linker flag to prevent dead-stripping of the static GPUBackendRegistration initializer.
Fix 5: Add an install() target for the nautilus-gpu library and headers.
Fix 6: Improve the gpu::launch() documentation, explaining the full mechanism, how grid/block dimensions are configured via engine::Options, and the CPU fallback semantics.
Fix 7: format.sh already covers plugins/ via git ls-files; no change needed.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
gpu::launch() now takes GridDim and BlockDim as traced val<uint32_t> values, enabling dynamic launch configuration computed from data:

```cpp
auto blocks = (n + val<uint32_t>(255)) / val<uint32_t>(256);
gpu::launch(kernel, gpu::GridDim{blocks}, gpu::BlockDim{256}, args...);
```

Implementation:
- gpu::launch() calls setGrid(x,y,z) and setBlock(x,y,z) before the kernel call. These are extern "C" functions (no-ops on CPU) that trace as ProxyCallOperations via invoke().
- GPU backends register setGrid/setBlock as intrinsics that capture the argument variable names into pending launch-config state.
- When the backend encounters the next kernel call, it uses the captured variable names in the launch syntax instead of the static Options values.
- Falls back to engine::Options if no setGrid/setBlock precedes the call.

CUDA generates: vecAdd<<<dim3(var_$8,1,1),dim3(256,1,1)>>>(args), where var_$8 = (n + 255) / 256 is computed at runtime.

Also fixes a GPU_TEST_DATA_FOLDER macro collision with the core TEST_DATA_FOLDER.

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
The setGrid/setBlock intrinsics were incorrectly causing host functions that call gpu::launch() to be classified as kernels. Now only device intrinsics (threadIdx, blockIdx, syncThreads, etc.) mark a function as a kernel; launch-config intrinsics are host-side and handled separately. This fixes the Metal output for gpuLaunchVecAddDynamic: the host function that computes grid dims and launches the kernel is now correctly omitted from MSL (only the kernel function is emitted).

https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Force-pushed from 6ff4c57 to 03ca011
The Metal backend now produces two outputs:
- .metal: MSL kernel source (device code, as before)
- .cpp: C++ host source with Objective-C Metal API dispatch
The host code uses Metal API to:
1. Create MTLDevice, MTLLibrary, MTLComputePipelineState
2. Set buffer arguments via [encoder setBytes:...]
3. Dispatch with dynamic grid/block dims from setGrid/setBlock
4. Wait for completion via [cmdBuf waitUntilCompleted]
The MetalCompilationBackend compiles the host .cpp via CPPCompiler
(the .metal is compiled separately via xcrun metal on macOS).
API change: lowerToMetal() now returns MetalOutput{deviceCode, hostCode}
instead of a single string. Codegen tests verify both outputs.
https://claude.ai/code/session_01FPxB7DqvnhQzbLMi7LWVr3
Force-pushed from 03ca011 to 317192f