diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py index 532fe7afa..994b58c9b 100644 --- a/bitsandbytes/_ops.py +++ b/bitsandbytes/_ops.py @@ -282,7 +282,6 @@ def _( A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int ) -> torch.Tensor: torch._check_is_size(blocksize) - torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") torch._check( A.dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", @@ -312,7 +311,6 @@ def _( out: torch.Tensor, ) -> None: torch._check_is_size(blocksize) - torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") torch._check( A.dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 95a7d9090..c22bcfa45 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -374,6 +374,10 @@ def matmul( return MatMul8bitLt.apply(A, B, out, bias, state) +# Above this limit, inference falls back to the dequantize + GEMM path. +FUSED_4BIT_DEQUANT_LIMIT = 8 + + def matmul_4bit( A: torch.Tensor, B: torch.Tensor, @@ -391,7 +395,8 @@ def matmul_4bit( else: return MatMul4Bit.apply(A, B, out, bias, quant_state) - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": + num_a_rows = A.numel() // A.shape[-1] + if num_a_rows <= FUSED_4BIT_DEQUANT_LIMIT and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 7799645db..bdf7056c3 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -472,10 +472,11 @@ def _gemv_4bit_impl( # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) + num_a_rows = A.numel() // A.shape[-1] + n = ct.c_int32(num_a_rows) k = ct.c_int32(shapeB[1]) - lda = m + lda = ct.c_int32(A.shape[-1]) ldb = ct.c_int32((A.shape[-1] + 1) // 2) ldc = m diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0d313c8d7..1383448ff 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1446,11 +1446,6 @@ __global__ void kgemm_4bit_inference_naive( int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, int lda, int ldb, int ldc, int blocksize ) { - - // per threadblock: - // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] - // THREADS/BNB_WARP_SIZE warps -> that many loads per iter - // 1xwarp_size * warp_size x warps -> 1 x warps outputs per thread block typedef bnb_cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / BNB_WARP_SIZE]; @@ -1458,104 +1453,148 @@ __global__ void kgemm_4bit_inference_naive( const int warp_lane = threadIdx.x % BNB_WARP_SIZE; const int row_B = (THREADS / BNB_WARP_SIZE) * blockIdx.x + warp_idx; const int offset_B = ldb * row_B; - const int num_values_8bit = num_values_4bit / 2; - float local_C = 0.0f; + constexpr int num_values_8bit = num_values_4bit / 2; + + float local_C0 = 0.0f; + float local_C1 = 0.0f; + float local_C2 = 0.0f; + float local_C3 = 0.0f; unsigned char local_B_4bit[num_values_8bit]; - T local_B[num_values_4bit / 4]; - T local_A[num_values_4bit / 4]; - __shared__ T quant_map[16]; - T local_absmax = T(0.0f); - - if (threadIdx.x < 16) - quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); - // for(int i = threadIdx.x; i < 16; i++) - // quant_map[i] = T(__ldg(&datatype[i])); + __shared__ float quant_map[32]; + float local_absmax = 0.0f; + + if (threadIdx.x < 16) { + float val = __ldg(&datatype[threadIdx.x]); + quant_map[threadIdx.x] = val; + quant_map[threadIdx.x + 16] = val; + } __syncthreads(); - // A: [1, K] - // B: [N, K] - for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE * num_values_4bit) { - const int inner_idx_halved = inner_idx / 2; + if (row_B >= M) return; - // Since blocksize will always be a power-of-2, we avoid more expensive - // division by the blocksize and instead use a shift operation. - // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. - const int absidx = ((2 * offset_B) + inner_idx) >> (31 - __clz(blocksize)); + const int stride = BNB_WARP_SIZE * num_values_4bit; + const int clz_blocksize = 31 - __clz(blocksize); + const int base_absidx = 2 * offset_B; + const int qm_offset = (warp_lane & 1) << 4; - local_absmax = __ldg(&(absmax[absidx])); + for (int n_idx = 0; n_idx < N; n_idx++) { + const T* __restrict__ A_row = A + n_idx * lda; - if (row_B < M) { - if ((inner_idx_halved + num_values_8bit) < (K / 2)) { - // this is the most important for performance considerations - reinterpret_cast(local_B_4bit)[0] = - reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; - } else { -#pragma unroll - for (int j = 0; j < (num_values_8bit); j++) - if ((inner_idx_halved) + j < (K / 2)) - local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; - else - local_B_4bit[j] = 0b01110111; - } - } else { -#pragma unroll - for (int j = 0; j < (num_values_8bit); j++) - local_B_4bit[j] = 0b01110111; + local_C0 = 0.0f; + local_C1 = 0.0f; + local_C2 = 0.0f; + local_C3 = 0.0f; + + int inner_idx = warp_lane * num_values_4bit; + int inner_idx_halved = inner_idx >> 1; + int4 prefetch_B; + float prefetch_absmax; + + if (inner_idx < K) { + prefetch_absmax = __ldg(&absmax[(base_absidx + inner_idx) >> clz_blocksize]); + if ((inner_idx_halved + num_values_8bit) < (K >> 1)) + prefetch_B = reinterpret_cast(B)[(offset_B + inner_idx_halved) / num_values_8bit]; } - for (int i = 0; i < 4; i++) { -#pragma unroll - for (int k = 0; k < num_values_8bit / 4; k++) { -#if BNB_BF16_AVAILABLE - local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; - local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; -#else - // bf16 multipliation not supported - local_B[k * 2] = - T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * (float)local_absmax); - local_B[k * 2 + 1] = - T((float)quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * (float)local_absmax); -#endif - } + for (; inner_idx < K; inner_idx += stride) { + inner_idx_halved = inner_idx >> 1; - if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { - // this is also relatively important for performance - if (BITS == 16) { - reinterpret_cast(local_A)[0] = - reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; - } else { - reinterpret_cast(local_A)[0] = - reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; - reinterpret_cast(local_A)[1] = - reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; - } + local_absmax = prefetch_absmax; - } else + if (__builtin_expect((inner_idx_halved + num_values_8bit) < (K >> 1), 1)) { + reinterpret_cast(local_B_4bit[0]) = prefetch_B; + } else { #pragma unroll - for (int k = 0; k < num_values_4bit / 4; k++) - if (inner_idx + (i * num_values_4bit / 4) + k < K) - local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; - else - local_A[k] = T(0.0f); + for (int j = 0; j < num_values_8bit; j++) + local_B_4bit[j] = ((inner_idx_halved + j) < (K >> 1)) ? B[offset_B + inner_idx_halved + j] : 0x77; + } + + int next_inner_idx = inner_idx + stride; + int next_inner_idx_halved = next_inner_idx >> 1; + if (next_inner_idx < K) { + prefetch_absmax = __ldg(&absmax[(base_absidx + next_inner_idx) >> clz_blocksize]); + if ((next_inner_idx_halved + num_values_8bit) < (K >> 1)) + prefetch_B = reinterpret_cast(B)[(offset_B + next_inner_idx_halved) / num_values_8bit]; + } -// accumulate in float; small performance hit for Ampere, but lower error for outputs + float b0 = quant_map[qm_offset + (local_B_4bit[0] >> 4)] * local_absmax; + float b1 = quant_map[qm_offset + (local_B_4bit[0] & 0xF)] * local_absmax; + float b2 = quant_map[qm_offset + (local_B_4bit[1] >> 4)] * local_absmax; + float b3 = quant_map[qm_offset + (local_B_4bit[1] & 0xF)] * local_absmax; + float b4 = quant_map[qm_offset + (local_B_4bit[2] >> 4)] * local_absmax; + float b5 = quant_map[qm_offset + (local_B_4bit[2] & 0xF)] * local_absmax; + float b6 = quant_map[qm_offset + (local_B_4bit[3] >> 4)] * local_absmax; + float b7 = quant_map[qm_offset + (local_B_4bit[3] & 0xF)] * local_absmax; + float b8 = quant_map[qm_offset + (local_B_4bit[4] >> 4)] * local_absmax; + float b9 = quant_map[qm_offset + (local_B_4bit[4] & 0xF)] * local_absmax; + float b10 = quant_map[qm_offset + (local_B_4bit[5] >> 4)] * local_absmax; + float b11 = quant_map[qm_offset + (local_B_4bit[5] & 0xF)] * local_absmax; + float b12 = quant_map[qm_offset + (local_B_4bit[6] >> 4)] * local_absmax; + float b13 = quant_map[qm_offset + (local_B_4bit[6] & 0xF)] * local_absmax; + float b14 = quant_map[qm_offset + (local_B_4bit[7] >> 4)] * local_absmax; + float b15 = quant_map[qm_offset + (local_B_4bit[7] & 0xF)] * local_absmax; + float b16 = quant_map[qm_offset + (local_B_4bit[8] >> 4)] * local_absmax; + float b17 = quant_map[qm_offset + (local_B_4bit[8] & 0xF)] * local_absmax; + float b18 = quant_map[qm_offset + (local_B_4bit[9] >> 4)] * local_absmax; + float b19 = quant_map[qm_offset + (local_B_4bit[9] & 0xF)] * local_absmax; + float b20 = quant_map[qm_offset + (local_B_4bit[10] >> 4)] * local_absmax; + float b21 = quant_map[qm_offset + (local_B_4bit[10] & 0xF)] * local_absmax; + float b22 = quant_map[qm_offset + (local_B_4bit[11] >> 4)] * local_absmax; + float b23 = quant_map[qm_offset + (local_B_4bit[11] & 0xF)] * local_absmax; + float b24 = quant_map[qm_offset + (local_B_4bit[12] >> 4)] * local_absmax; + float b25 = quant_map[qm_offset + (local_B_4bit[12] & 0xF)] * local_absmax; + float b26 = quant_map[qm_offset + (local_B_4bit[13] >> 4)] * local_absmax; + float b27 = quant_map[qm_offset + (local_B_4bit[13] & 0xF)] * local_absmax; + float b28 = quant_map[qm_offset + (local_B_4bit[14] >> 4)] * local_absmax; + float b29 = quant_map[qm_offset + (local_B_4bit[14] & 0xF)] * local_absmax; + float b30 = quant_map[qm_offset + (local_B_4bit[15] >> 4)] * local_absmax; + float b31 = quant_map[qm_offset + (local_B_4bit[15] & 0xF)] * local_absmax; + + if (__builtin_expect(inner_idx + 32 <= K, 1)) { + int4 a_vec0 = reinterpret_cast(A_row)[inner_idx / 8]; + int4 a_vec1 = reinterpret_cast(A_row)[inner_idx / 8 + 1]; + int4 a_vec2 = reinterpret_cast(A_row)[inner_idx / 8 + 2]; + int4 a_vec3 = reinterpret_cast(A_row)[inner_idx / 8 + 3]; + + const T* a0 = reinterpret_cast(&a_vec0); + const T* a1 = reinterpret_cast(&a_vec1); + const T* a2 = reinterpret_cast(&a_vec2); + const T* a3 = reinterpret_cast(&a_vec3); + + local_C0 += (float)a0[0]*b0; local_C1 += (float)a0[1]*b1; + local_C2 += (float)a0[2]*b2; local_C3 += (float)a0[3]*b3; + local_C0 += (float)a0[4]*b4; local_C1 += (float)a0[5]*b5; + local_C2 += (float)a0[6]*b6; local_C3 += (float)a0[7]*b7; + local_C0 += (float)a1[0]*b8; local_C1 += (float)a1[1]*b9; + local_C2 += (float)a1[2]*b10; local_C3 += (float)a1[3]*b11; + local_C0 += (float)a1[4]*b12; local_C1 += (float)a1[5]*b13; + local_C2 += (float)a1[6]*b14; local_C3 += (float)a1[7]*b15; + local_C0 += (float)a2[0]*b16; local_C1 += (float)a2[1]*b17; + local_C2 += (float)a2[2]*b18; local_C3 += (float)a2[3]*b19; + local_C0 += (float)a2[4]*b20; local_C1 += (float)a2[5]*b21; + local_C2 += (float)a2[6]*b22; local_C3 += (float)a2[7]*b23; + local_C0 += (float)a3[0]*b24; local_C1 += (float)a3[1]*b25; + local_C2 += (float)a3[2]*b26; local_C3 += (float)a3[3]*b27; + local_C0 += (float)a3[4]*b28; local_C1 += (float)a3[5]*b29; + local_C2 += (float)a3[6]*b30; local_C3 += (float)a3[7]*b31; + } else { + float b_vals[32] = {b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,b15, + b16,b17,b18,b19,b20,b21,b22,b23,b24,b25,b26,b27,b28,b29,b30,b31}; #pragma unroll - for (int k = 0; k < num_values_4bit / 4; k++) { -#if BNB_BF16_AVAILABLE - local_C += (float)(local_A[k] * local_B[k]); -#else - // bf16 multipliation not supported - local_C += ((float)local_A[k] * (float)local_B[k]); -#endif + for (int k = 0; k < 32; k++) { + float a_val = (inner_idx + k < K) ? (float)A_row[inner_idx + k] : 0.0f; + local_C0 += a_val * b_vals[k]; + } } } - } - local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + float local_C = local_C0 + local_C1 + local_C2 + local_C3; + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); - if (row_B < M && warp_lane == 0) - out[row_B] = T(local_C); + if (warp_lane == 0) + out[n_idx * ldc + row_B] = T(local_C); + } } template __global__ void kfunc(T* A, T* B, T value, long n) { @@ -1595,6 +1634,18 @@ template __global__ void kgemm_4bit_inference_naive( int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, float* out, int lda, int ldb, int ldc, int blocksize ); +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, half* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, half* out, + int lda, int ldb, int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, bnb_bfloat16* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + bnb_bfloat16* out, int lda, int ldb, int ldc, int blocksize +); +template __global__ void kgemm_4bit_inference_naive( + int M, int N, int K, float* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, + float* out, int lda, int ldb, int ldc, int blocksize +); template __global__ void kdequant_mm_int32_fp16<4, 512>( int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, diff --git a/csrc/ops.cu b/csrc/ops.cu index e76834785..b796722ab 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -421,15 +421,21 @@ void gemm_4bit_inference_naive( int blocksize, bnb_stream_t stream ) { - int num_blocks = (m + 3) / 4; #if BNB_HIP - if (bnb_host_warp_size() == 64) { - num_blocks = (m + 1) / 2; + const int ws = bnb_host_warp_size(); + int num_blocks = (m + 1) / 2; + if (ws == 32) { + kgemm_4bit_inference_naive + <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + } else { + kgemm_4bit_inference_naive + <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } -#endif - +#else + int num_blocks = (m + 3) / 4; kgemm_4bit_inference_naive <<>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); +#endif BNB_CHECK_RETURN(BNB_PEEK_LAST_ERROR()); }