diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 82a6058bf..f5c21364d 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -96,6 +96,38 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
   return scale * accum + sum * bias;
 }
 
+template <typename U, int values_per_thread, int bits>
+inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  if (bits == 2) {
+    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
+      result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
+      result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
+      result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
+    }
+  }
+
+  else if (bits == 4) {
+    const thread uint16_t* ws = (const thread uint16_t*)w;
+    U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
+      result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
+      result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
+      result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < values_per_thread; i++) {
+      result[i] += x * (scale * w[i] + bias);
+    }
+  }
+}
+
 template <typename T, const int group_size, const int bits>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
@@ -268,7 +300,7 @@ template <typename T, const int group_size, const int bits>
 }
 
-template <typename T, const int BM, const int BN, const int group_size, const int bits>
+template <typename T, const int group_size, const int bits>
 [[kernel]] void qvm(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -278,39 +310,28 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
     const constant int& in_vec_size [[buffer(5)]],
     const constant int& out_vec_size [[buffer(6)]],
     uint3 tid [[threadgroup_position_in_grid]],
-    uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
 
-  static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
-  static_assert(BN == BM, "qvm expects a block size of 32x32");
-
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_int = 32 / bits;
-  constexpr int groups_per_block = (BN * el_per_int) / group_size;
+  constexpr int num_simdgroups = 8;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int blocksize = SIMD_SIZE;
 
-  typedef typename AccT<T>::acc_t U;
-  threadgroup U scales_block[BM * groups_per_block];
-  threadgroup U biases_block[BM * groups_per_block];
-  threadgroup U x_block[BM];
+  typedef float U;
 
   thread uint32_t w_local;
-  thread U result[el_per_int] = {0};
+  thread U result[pack_factor] = {0};
   thread U scale = 1;
   thread U bias = 0;
   thread U x_local = 0;
 
   // Adjust positions
-  const int out_vec_size_w = out_vec_size / el_per_int;
+  const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col_start = tid.y * (BN * el_per_int);
-  int out_col = out_col_start + simd_gid * el_per_int;
-  w += out_col / el_per_int;
-  scales += out_col_start / group_size;
-  biases += out_col_start / group_size;
+  int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
+  w += out_col / pack_factor;
+  scales += out_col / group_size;
+  biases += out_col / group_size;
   x += tid.z * in_vec_size;
   y += tid.z * out_vec_size + out_col;
 
@@ -318,53 +339,39 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
   if (out_col >= out_vec_size) {
     return;
   }
 
-  // Loop over in_vec in blocks of BM
-  for (int i = 0; i < in_vec_size; i += BM) {
+  // Loop over in_vec in blocks of blocksize
+  int i = 0;
+  for (; i + blocksize <= in_vec_size; i += blocksize) {
     x_local = x[i + simd_lid];
     scale = scales[(i + simd_lid) * out_vec_size_g];
     bias = biases[(i + simd_lid) * out_vec_size_g];
     w_local = w[(i + simd_lid) * out_vec_size_w];
 
-    #pragma clang loop unroll(full)
-    for (int k=0; k<el_per_int; k++) {
-      result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
-      w_local >>= bits;
-    }
+    qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
   }
+  if (static_cast<int>(i + simd_lid) < in_vec_size) {
+    x_local = x[i + simd_lid];
+    scale = scales[(i + simd_lid) * out_vec_size_g];
+    bias = biases[(i + simd_lid) * out_vec_size_g];
+    w_local = w[(i + simd_lid) * out_vec_size_w];
+  } else {
+    x_local = 0;
+    scale = 0;
+    bias = 0;
+    w_local = 0;
+  }
+  qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
 
   // Accumulate in the simdgroup
   #pragma clang loop unroll(full)
-  for (int k=0; k<el_per_int; k++) {
+  for (int k=0; k<pack_factor; k++) {
     result[k] = simd_sum(result[k]);
   }
 
   // Store the result
   if (simd_lid == 0) {
     #pragma clang loop unroll(full)
-    for (int k=0; k<el_per_int; k++) {
+    for (int k=0; k<pack_factor; k++) {
       y[k] = static_cast<T>(result[k]);
     }
   }
 }
@@ -738,7 +745,7 @@ instantiate_qmv_types( 32, 8)
 #define instantiate_qvm(name, itype, group_size, bits) \
   template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
+  [[kernel]] void qvm<itype, group_size, bits>( \
      const device itype* x [[buffer(0)]], \
      const device uint32_t* w [[buffer(1)]], \
      const device itype* scales [[buffer(2)]], \
@@ -747,7 +754,6 @@ instantiate_qmv_types( 32, 8)
      const constant int& in_vec_size [[buffer(5)]], \
      const constant int& out_vec_size [[buffer(6)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
-     uint lid [[thread_index_in_threadgroup]], \
      uint simd_gid [[simdgroup_index_in_threadgroup]], \
      uint simd_lid [[thread_index_in_simdgroup]]);
 
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index eaddfe6e8..58e42ed9b 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -137,7 +137,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       auto kernel = d.get_kernel(kname.str());
       compute_encoder->setComputePipelineState(kernel);
 
-      int bo = std::min(32, O);
+      int bo = 8;
       int bd = 32;
       MTL::Size group_dims = MTL::Size(bd, bo, 1);
       MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
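Note on the unpacking trick in qouter above (editorial note, not part of the patch): rather than shifting each packed value down before scaling, the kernel masks it in place and folds the shift into a premultiplied scale, so s[j] * (w & mask_j) equals scale * ((w >> (j * bits)) & low_mask). Because the divisors 4, 16, 64, ..., 4096 are powers of two, the divided scales are exact in float and the two forms agree bit for bit. Below is a minimal host-side C++ sketch checking that identity for the 4-bit branch; the names s and masks mirror the kernel, and the program itself is illustrative only.

  #include <cassert>
  #include <cstdint>

  int main() {
    const float scale = 0.37f;
    // Premultiplied scales, as in qouter's 4-bit branch: scale / 16^j.
    const float s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
    const uint16_t masks[4] = {0x000f, 0x00f0, 0x0f00, 0xf000};

    // Exhaustively compare the mask-without-shift form against the
    // conventional shift-then-mask form for every 16-bit packed word.
    for (uint32_t w = 0; w <= 0xffff; w++) {
      for (int j = 0; j < 4; j++) {
        float folded = s[j] * (w & masks[j]);           // kernel's form
        float shifted = scale * ((w >> (4 * j)) & 0xf); // reference form
        assert(folded == shifted); // exact: power-of-two scaling is lossless
      }
    }
    return 0;
  }

The apparent motivation is that the four premultiplied scales are computed once per qouter call, trading a per-element shift for a multiply by a precomputed constant.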