[ROCm] Optimize kgemm_4bit_inference_naive for ROCm, use it for batch sizes other than 1 #1920
Conversation
Hi @sstamenk! Thank you for sharing. This is very timely, as I'm actually working on the same thing right now for NVIDIA hardware. First, I have a SIMT kernel intended to replace the current one. Secondly, I've been taking it a step further and have tensor core MMA kernels as well, compatible with NVIDIA GPUs from Turing and newer; these can extend the wins well beyond M=16. Porting those will be quite a bit more involved for ROCm. I'm developing heuristics to decide which kernels to dispatch, much like what you're doing here, but with much more nuance, as there are up to 7 fused kernel choices (SIMT and 6 tensor core variants). So, in principle, I agree with doing this. However, I'll also be taking some care to refine the custom op interface for this. It's not quite ready yet, but are you open to taking a look at the SIMT version of the kernel I've been working on, and seeing how well it can be ported and how it compares? If anything, we could use whichever version is faster on AMD hardware. I can provide more detail during the week.
Hi @matthewdouglas, glad to hear this is something you've flagged as a possible improvement. The heuristic for kernel selection also came up during my investigation, so it's good to hear you're already working on it. The kernel solutions you mention sound very interesting; I'm open to taking a look and testing/helping with porting to AMD hardware! MFMA/WMMA kernels might be a more involved endeavor, as you correctly pointed out, due to the architecture differences; let's discuss this more once you share more details. Is there a bitsandbytes Discord or Slack channel where we can shift the discussion for easier communication, or would you like to keep it to the PR page?
@sstamenk We have a shared Slack channel managed by AMD called #bitsandbytes-amd-collab; I think someone on the AMD side would need to invite you. It's been very inactive, but it would be a good place to collaborate. Something to point out, though, is that vLLM is currently considering deprecating/removing bitsandbytes support. To give a little more detail on my WIP kernel replacement for
Some quick numbers from RTX 4090, M=1:
For M=2-7, the average win over the dequant+cublas path is 2.8x, but can be as high as 9.6x.
Based on issues raised in #1842 and pytorch#171687.
Summary
- Optimize `kgemm_4bit_inference_naive` on ROCm, following the suggestions discussed in #1842 (70B 4-bit LLM decode bottlenecked by HIP kernel (`kgemm_4bit_inference_naive`) efficiency — 49% vs 91% memory bandwidth on ROCm/gfx1151).
- Use the fused kernel for batched (`M > 1`) inputs (e.g. `Req = 2`) instead of only the vector case.

Technical details
This PR makes two related changes.
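As background for both changes: the split path materializes the full dequantized weight matrix before running a regular GEMM, while the fused path dequantizes each 4-bit weight on the fly inside the dot product. A toy pure-Python sketch of the two paths (illustrative only: the codebook values and shapes are made up, not the real NF4 table or the HIP kernel):

```python
# Toy model of the two paths. CODEBOOK stands in for a 4-bit quantization
# table (the real NF4 table has 16 entries); q holds codebook indices.
CODEBOOK = [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]

def split_path(x_rows, q_cols):
    # Dequantize the whole weight matrix first, then do a plain matmul.
    w_cols = [[CODEBOOK[i] for i in col] for col in q_cols]
    return [[sum(a * b for a, b in zip(x, w)) for w in w_cols] for x in x_rows]

def fused_path(x_rows, q_cols):
    # Dequantize each weight on the fly inside the dot product; the full
    # dequantized matrix is never written out.
    return [[sum(a * CODEBOOK[i] for a, i in zip(x, col)) for col in q_cols]
            for x in x_rows]

x = [[1.0, 2.0, 3.0]]          # M=1 activation row
q = [[0, 3, 6], [2, 2, 5]]     # two quantized weight columns
assert split_path(x, q) == fused_path(x, q)  # both paths agree numerically
```

For small `M` the fused path reads the quantized weights once and never writes the dequantized matrix to memory, which is why it wins in the bandwidth-bound decode regime.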
Kernel optimization
- Restructures `kgemm_4bit_inference_naive` to reduce overhead in the fused dequantize + matmul path on ROCm.
- Addresses 70B 4-bit LLM decode bottlenecked by HIP kernel (`kgemm_4bit_inference_naive`) efficiency — 49% vs 91% memory bandwidth on ROCm/gfx1151 #1842

Fused path support for `M > 1`

Previously, the fused kernel was used only when `M == 1`. Up to a platform-specific crossover point, launching the fused kernel is substantially faster than falling back to split dequantize + GEMM. This matters most for serving workloads, where decode steps regularly hit small `M > 1` batches. An example measured on Strix Halo appears in the results below.
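In Python pseudocode, the resulting dispatch rule looks roughly like the following. The helper name and the inclusive bound are assumptions of this sketch, not the actual bitsandbytes code; the threshold value itself is the one this PR settles on.

```python
# Hypothetical sketch of the dispatch rule: small batches take the fused
# dequantize + matmul kernel, larger ones fall back to split dequantize + GEMM.
# The threshold of 8 is the cross-GPU compromise chosen in this PR; treating
# the bound as inclusive is an assumption of this sketch.
FUSED_M_THRESHOLD = 8

def pick_path(m: int) -> str:
    """Return which path a batch of M rows should take."""
    return "fused" if 1 <= m <= FUSED_M_THRESHOLD else "split"

assert pick_path(1) == "fused"   # single-sequence decode
assert pick_path(8) == "fused"   # at the compromise threshold
assert pick_path(16) == "split"  # past the crossover on most GPUs tested
```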
At larger `M`, the fused path eventually converges with, and then regresses against, split dequantize + GEMM. The crossover differs by GPU:

| GPU | Crossover `M` |
|---|---|
| gfx1151 | 16 |
| RTX 5090 | 8-12 |
| gfx1201 | 10-12 |
| MI308X | 4-6 |

For this PR, the dispatch threshold is set to `M=8` as a cross-GPU compromise. That still leaves some regressions on MI308X once `reqs >= 6`, but avoids the larger regressions seen at higher thresholds on other GPUs.

Testing plan
- `gemm_4bit` unit tests to validate correctness of the updated kernel path.
- Benchmarks to confirm the `M > 1` fused-path performance gain.

Testing results
gemv_4bit unit tests
`kgemm_4bit_inference_naive` benchmark

In this table, `A` denotes the baseline kernel and `B` denotes the optimized kernel.

| GPU | Time (A) | Time (B) | BW (A) | BW (B) | Peak BW | Util (A) | Util (B) | Speedup |
|---|---|---|---|---|---|---|---|---|
| gfx1151 | 1133 us | 740 us | 117 GB/s | 178 GB/s | ~210 GB/s (measured) | 56% | 85% | 1.53x |
| RTX 5090 | 86 us | 84 us | 1361 GB/s | 1394 GB/s | ~1,790 GB/s | 76% | 78% | 1.02x |
| gfx1201 | 539 us | 226 us | 218 GB/s | 519 GB/s | 640 GB/s | 34% | 81% | 2.39x |
| MI308X | 656 us | 246 us | 179 GB/s | 477 GB/s | ~3,277 GB/s | 5.5% | 14.6% | 2.67x |

End-to-end Transformers Throughput
Strix Halo (gfx1151):

| Baseline (tok/s) | Optimized (tok/s) | Speedup |
|---|---|---|
| 2.45 | 3.86 | 1.58x |
| 18.3 | 27.9 | 1.53x |
| 10.7 | 15.6 | 1.46x |
| 9.6 | 12.5 | 1.30x |
| 17.4 | 22.3 | 1.28x |

RTX 5090:

| Baseline (tok/s) | Optimized (tok/s) | Speedup |
|---|---|---|
| 85.57 | 84.32 | 0.99x |
| 82.43 | 80.76 | 0.98x |
| 59.51 | 58.95 | 0.99x |

Radeon AI Pro R9700 (gfx1201):

| Baseline (tok/s) | Optimized (tok/s) | Speedup |
|---|---|---|
| 38.42 | 64.32 | 1.67x |
| 31.31 | 46.27 | 1.48x |
| 23.83 | 32.00 | 1.34x |

MI308X (gfx942):

| Baseline (tok/s) | Optimized (tok/s) | Speedup |
|---|---|---|
| 31.28 | 40.51 | 1.30x |
| 30.97 | 40.46 | 1.31x |
| 23.45 | 29.03 | 1.24x |
| 15.62 | 15.94 | 1.02x |
| 4.59 | 10.60 | 2.31x |
| 4.59 | 10.61 | 2.31x |

End-to-end vLLM Serving Throughput for
`Reqs > 1`

Strix Halo (gfx1151)

Mistral-7B
| tok/s | tok/s | tok/s | tok/s | Speedup |
|---|---|---|---|---|
| 22.5 fused | 34.7 fused | 35.3 fused | 35.7 fused | 1.59x |
| 10.4 split | 10.4 split | 51.8 fused | 53.4 fused | 5.13x |
| 20.3 split | 20.3 split | 67.6 fused | 67.3 fused | 3.32x |
| 30.4 split | 30.5 split | 72.4 fused | 71.9 fused | 2.37x |
| 40.4 split | 40.4 split | 75.2 fused | 75.1 fused | 1.86x |
| 50.5 split | 50.6 split | 50.6 split | 76.7 fused | 1.52x |
| 60.2 split | 60.4 split | 60.4 split | 77.9 fused | 1.29x |
| 69.9 split | 70.1 split | 70.1 split | 78.7 fused | 1.13x |
| 80.1 split | 80.2 split | 80.3 split | 79.2 fused | 0.99x |

Llama-8B
| tok/s | tok/s | tok/s | tok/s | Speedup |
|---|---|---|---|---|
| 20.7 fused | 32.0 fused | 32.0 fused | 32.0 fused | 1.55x |
| 10.4 split | 10.4 split | 48.9 fused | 47.3 fused | 4.55x |
| 20.3 split | 20.3 split | 63.7 fused | 63.6 fused | 3.13x |
| 30.3 split | 30.2 split | 68.8 fused | 68.5 fused | 2.26x |
| 40.2 split | 40.1 split | 72.4 fused | 72.4 fused | 1.80x |
| 50.2 split | 50.2 split | 50.2 split | 74.6 fused | 1.49x |
| 60.0 split | 60.0 split | 59.9 split | 75.8 fused | 1.26x |
| 69.5 split | 69.6 split | 69.6 split | 76.9 fused | 1.11x |
| 79.5 split | 79.5 split | 79.5 split | 77.2 fused | 0.97x |

Qwen3.5-9B
| tok/s | tok/s | tok/s | tok/s | Speedup |
|---|---|---|---|---|
| 17.5 fused | 23.1 fused | 22.9 fused | 22.9 fused | 1.31x |
| 9.4 split | 9.4 split | 40.4 fused | 39.5 fused | 4.20x |
| 18.4 split | 18.4 split | 55.8 fused | 57.0 fused | 3.10x |
| 26.9 split | 27.0 split | 61.8 fused | 62.0 fused | 2.30x |
| 35.4 split | 35.5 split | 65.9 fused | 66.0 fused | 1.86x |
| 44.5 split | 44.6 split | 44.7 split | 68.7 fused | 1.54x |
| 52.0 split | 52.3 split | 52.3 split | 70.6 fused | 1.36x |
| 60.1 split | 60.3 split | 60.4 split | 72.1 fused | 1.20x |
| 69.0 split | 69.3 split | 69.5 split | 72.9 fused | 1.06x |

Llama-3.3-70B
| tok/s | tok/s | tok/s | tok/s | Speedup |
|---|---|---|---|---|
| 2.5 fused | 4.1 fused | 4.1 fused | 4.0 fused | 1.60x |
| 1.2 split | 1.2 split | 5.9 fused | 5.8 fused | 4.83x |
| 2.4 split | - | 7.4 fused | 7.4 fused | 3.08x |
| 3.6 split | - | 7.8 fused | 7.8 fused | 2.17x |
| 4.8 split | - | 8.0 fused | 8.0 fused | 1.67x |
| 6.0 split | - | 6.0 split | 8.1 fused | 1.35x |
| 7.2 split | - | - | 8.2 fused | 1.14x |
| 8.4 split | - | - | 8.3 fused | 0.99x |
| 9.5 split | - | - | 8.3 fused | 0.87x |

RTX 5090

Mistral-7B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 134.9 fused | 136.0 fused | 135.3 fused | 131.4 fused | 1.00x | 0.97x |
| 104.0 split | 103.9 split | 255.5 fused | 243.8 fused | 2.46x | 2.34x |
| 204.9 split | 204.6 split | 347.0 fused | 343.1 fused | 1.69x | 1.67x |
| 283.0 split | 283.0 split | 385.1 fused | 382.0 fused | 1.36x | 1.35x |
| 375.8 split | 376.0 split | 404.5 fused | 401.9 fused | 1.08x | 1.07x |
| 422.4 split | 422.4 split | 420.9 split | 407.0 fused | 1.00x | 0.96x |
| 469.1 split | 468.4 split | 469.3 split | 411.9 fused | 1.00x | 0.88x |
| 558.4 split | 559.8 split | 558.6 split | 415.9 fused | 1.00x | 0.74x |
| 736.8 split | 736.7 split | 737.3 split | 425.5 fused | 1.00x | 0.58x |

Llama-8B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 136.6 fused | 134.1 fused | 133.0 fused | 133.3 fused | 0.97x | 0.98x |
| 101.5 split | 101.4 split | 251.4 fused | 245.6 fused | 2.48x | 2.42x |
| 200.0 split | 199.5 split | 333.6 fused | 330.7 fused | 1.67x | 1.65x |
| 275.7 split | 275.7 split | 373.8 fused | 374.3 fused | 1.36x | 1.36x |
| 365.9 split | 365.0 split | 394.3 fused | 395.4 fused | 1.08x | 1.08x |
| 410.7 split | 410.8 split | 411.0 split | 399.3 fused | 1.00x | 0.97x |
| 456.0 split | 456.0 split | 456.8 split | 404.5 fused | 1.00x | 0.89x |
| 544.9 split | 545.2 split | 545.4 split | 410.1 fused | 1.00x | 0.75x |
| 720.7 split | 720.5 split | 720.7 split | 420.5 fused | 1.00x | 0.58x |

Qwen3.5-9B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 72.3 fused | 72.6 fused | 73.4 fused | 72.2 fused | 1.02x | 1.00x |
| 100.0 split | 100.0 split | 135.3 fused | 132.4 fused | 1.35x | 1.32x |
| 188.7 split | 188.5 split | 271.1 fused | 264.8 fused | 1.44x | 1.40x |
| 280.0 split | 280.0 split | 344.4 fused | 343.2 fused | 1.23x | 1.23x |
| 370.4 split | 370.4 split | 369.3 fused | 368.2 fused | 1.00x | 0.99x |
| 415.6 split | 415.7 split | 415.6 split | 375.0 fused | 1.00x | 0.90x |
| 462.1 split | 462.1 split | 462.4 split | 382.2 fused | 1.00x | 0.83x |
| 545.8 split | 545.8 split | 545.9 split | 390.6 fused | 1.00x | 0.72x |
| 737.9 split | 738.1 split | 738.1 split | 400.8 fused | 1.00x | 0.54x |

Radeon AI Pro R9700 (gfx1201)

Mistral-7B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 45.5 fused | 87.2 fused | 90.0 fused | 88.1 fused | 1.98x | 1.94x |
| 34.1 split | 34.2 split | 127.9 fused | 119.6 fused | 3.75x | 3.51x |
| 68.1 split | 67.7 split | 150.4 fused | 147.1 fused | 2.21x | 2.16x |
| 134.2 split | 134.2 split | 166.6 fused | 163.9 fused | 1.24x | 1.22x |
| 151.2 split | 150.7 split | 151.0 split | 167.2 fused | 1.00x | 1.11x |
| 167.7 split | 167.0 split | 167.1 split | 167.3 fused | 1.00x | 1.00x |
| 184.1 split | 183.3 split | 183.6 split | 166.9 fused | 1.00x | 0.91x |
| 199.0 split | 198.7 split | 199.7 split | 169.3 fused | 1.00x | 0.85x |
| 263.2 split | 261.5 split | 262.6 split | 170.2 fused | 1.00x | 0.65x |

Llama-8B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 44.1 fused | 80.9 fused | 80.9 fused | 79.7 fused | 1.84x | 1.81x |
| 33.5 split | 33.4 split | 117.8 fused | 112.1 fused | 3.52x | 3.35x |
| 66.6 split | 66.4 split | 142.1 fused | 140.7 fused | 2.13x | 2.11x |
| 132.4 split | 132.0 split | 159.7 fused | 160.6 fused | 1.21x | 1.21x |
| 147.3 split | 147.1 split | 147.1 split | 162.5 fused | 1.00x | 1.10x |
| 163.1 split | 162.9 split | 162.8 split | 162.4 fused | 1.00x | 1.00x |
| 179.2 split | 178.4 split | 178.7 split | 160.4 fused | 1.00x | 0.90x |
| 195.1 split | 194.9 split | 194.8 split | 165.5 fused | 1.00x | 0.85x |
| 256.1 split | 255.8 split | 256.3 split | 167.5 fused | 1.00x | 0.65x |

Qwen3.5-9B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup |
|---|---|---|---|---|---|
| 9.4 fused | 10.8 fused | 10.8 fused | 10.8 fused | 1.15x | 1.15x |
| 13.4 split | 13.5 split | 19.8 fused | 19.7 fused | 1.48x | 1.47x |
| 26.8 split | 26.7 split | 36.3 fused | 36.3 fused | 1.35x | 1.35x |
| 52.7 split | 52.7 split | 61.3 fused | 61.1 fused | 1.16x | 1.16x |
| 59.3 split | 59.2 split | 59.3 split | 64.1 fused | 1.00x | 1.08x |
| 65.4 split | 65.4 split | 65.4 split | 69.2 fused | 1.00x | 1.06x |
| 72.3 split | 72.3 split | 72.2 split | 73.6 fused | 1.00x | 1.02x |
| 78.7 split | 78.6 split | 78.6 split | 78.0 fused | 1.00x | 0.99x |
| 103.2 split | 103.1 split | 103.2 split | 92.1 fused | 1.00x | 0.89x |

MI308X (gfx942)

Mistral-7B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup | Speedup |
|---|---|---|---|---|---|---|
| 37.6 fused | 61.3 fused | 64.2 fused | 61.7 fused | 1.63x | 1.71x | 1.64x |
| 47.3 split | 48.0 split | 98.8 fused | 92.1 fused | 1.01x | 2.09x | 1.95x |
| 94.4 split | 95.1 split | 112.8 fused | 111.1 fused | 1.01x | 1.19x | 1.18x |
| 141.3 split | 142.0 split | 120.2 fused | 118.8 fused | 1.00x | 0.85x | 0.84x |
| 188.5 split | 189.2 split | 124.0 fused | 123.2 fused | 1.00x | 0.66x | 0.65x |
| 214.0 split | 215.6 split | 192.7 split | 124.6 fused | 1.01x | 0.90x | 0.58x |
| 237.5 split | 239.0 split | 238.9 split | 125.9 fused | 1.01x | 1.01x | 0.53x |
| 284.8 split | 287.1 split | 285.6 split | 128.9 fused | 1.01x | 1.00x | 0.45x |
| 379.2 split | 382.0 split | 381.3 split | 130.1 fused | 1.01x | 1.01x | 0.34x |

Llama-8B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup | Speedup |
|---|---|---|---|---|---|---|
| 37.3 fused | 63.9 fused | 64.9 fused | 62.1 fused | 1.71x | 1.74x | 1.66x |
| 47.1 split | 47.6 split | 96.3 fused | 91.5 fused | 1.01x | 2.04x | 1.94x |
| 90.8 split | 93.7 split | 111.7 fused | 110.3 fused | 1.03x | 1.23x | 1.21x |
| 139.8 split | 139.0 split | 117.9 fused | 117.6 fused | 0.99x | 0.84x | 0.84x |
| 186.9 split | 185.7 split | 122.6 fused | 121.6 fused | 0.99x | 0.66x | 0.65x |
| 212.8 split | 213.2 split | 211.8 split | 123.7 fused | 1.00x | 1.00x | 0.58x |
| 235.6 split | 237.0 split | 236.9 split | 125.3 fused | 1.01x | 1.01x | 0.53x |
| 283.2 split | 283.3 split | 284.5 split | 127.0 fused | 1.00x | 1.00x | 0.45x |
| 376.5 split | 377.2 split | 377.6 split | 129.8 fused | 1.00x | 1.00x | 0.34x |

Qwen3.5-9B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup | Speedup |
|---|---|---|---|---|---|---|
| 30.2 fused | 30.0 fused | 30.1 fused | 29.8 fused | 0.99x | 1.00x | 0.99x |
| 40.0 split | 39.8 split | 57.6 fused | 58.8 fused | 1.00x | 1.44x | 1.47x |
| 79.7 split | 78.9 split | 96.9 fused | 96.5 fused | 0.99x | 1.22x | 1.21x |
| 119.4 split | 119.0 split | 107.0 fused | 106.2 fused | 1.00x | 0.90x | 0.89x |
| 159.4 split | 159.0 split | 112.2 fused | 112.0 fused | 1.00x | 0.70x | 0.70x |
| 179.7 split | 178.9 split | 177.8 split | 114.2 fused | 1.00x | 0.99x | 0.64x |
| 199.3 split | 182.9 split | 197.7 split | 115.9 fused | 0.92x | 0.99x | 0.58x |
| 239.6 split | 236.7 split | 224.1 split | 118.5 fused | 0.99x | 0.94x | 0.49x |
| 317.6 split | 318.1 split | 315.2 split | 122.0 fused | 1.00x | 0.99x | 0.38x |

Llama-3.3-70B
| tok/s | tok/s | tok/s | tok/s | Speedup | Speedup | Speedup |
|---|---|---|---|---|---|---|
| 4.7 fused | 11.3 fused | 11.3 fused | 10.7 fused | 2.40x | 2.40x | 2.28x |
| 5.4 split | 5.4 split | 12.6 fused | 11.9 fused | 1.00x | 2.33x | 2.20x |
| 10.7 split | 10.7 split | 13.8 fused | 13.5 fused | 1.00x | 1.29x | 1.26x |
| 16.1 split | 16.1 split | 14.3 fused | 14.1 fused | 1.00x | 0.89x | 0.88x |
| 21.4 split | 21.4 split | 14.6 fused | 14.4 fused | 1.00x | 0.68x | 0.67x |
| 24.2 split | 24.2 split | 24.1 split | 14.7 fused | 1.00x | 1.00x | 0.61x |
| 26.9 split | 26.9 split | 26.9 split | 14.8 fused | 1.00x | 1.00x | 0.55x |
| 32.1 split | 32.1 split | 32.1 split | 14.9 fused | 1.00x | 1.00x | 0.46x |
| 42.7 split | 42.7 split | 42.7 split | 14.9 fused | 1.00x | 1.00x | 0.35x |
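A closing sanity check: the speedup columns throughout the tables above are plain ratios of the paired measurements (throughput divides optimized by baseline; kernel time, where lower is better, divides baseline by optimized). A small helper with a few values spot-checked from the tables in this description:

```python
# Spot-check speedup ratios against values reported in this PR description.
def speedup(baseline: float, optimized: float, lower_is_better: bool = False) -> float:
    """Ratio convention used in the tables: >1.0 means the optimized run wins."""
    return baseline / optimized if lower_is_better else optimized / baseline

# gfx1151 kernel benchmark: 1133 us -> 740 us (reported 1.53x).
assert round(speedup(1133, 740, lower_is_better=True), 2) == 1.53
# MI308X kernel benchmark: 656 us -> 246 us (reported 2.67x).
assert round(speedup(656, 246, lower_is_better=True), 2) == 2.67
# Strix Halo Mistral-7B serving, first row: 22.5 -> 35.7 tok/s (reported 1.59x).
assert round(speedup(22.5, 35.7), 2) == 1.59
```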