diff --git a/modules/util/triton_mm_8bit.py b/modules/util/triton_mm_8bit.py index 8250e1270..522959d89 100644 --- a/modules/util/triton_mm_8bit.py +++ b/modules/util/triton_mm_8bit.py @@ -47,7 +47,7 @@ ) @triton.jit -def __mm_kernel( +def _mm_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, @@ -109,7 +109,7 @@ def mm_8bit(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: def grid(META): return (triton.cdiv(N, META['BLOCK_SIZE_N']) , triton.cdiv(M, META['BLOCK_SIZE_M']), ) - __mm_kernel[grid]( + _mm_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1),