From 14b4e51a7c6455a61a74d24da9f47dfeb161023f Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Mar 2024 17:32:19 -0800
Subject: [PATCH] Improved quantized matrix vector product (#786)

---
 mlx/backend/metal/kernels/quantized.metal | 316 +++++++++++++++++-----
 mlx/backend/metal/quantized.cpp           |  33 ++-
 mlx/ops.cpp                               |   9 +
 3 files changed, 284 insertions(+), 74 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index c2bfba9f9..82a6058bf 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <metal_stdlib>
 #include <metal_simdgroup>
@@ -23,7 +23,143 @@ template <> struct AccT<bfloat16_t> {
   typedef float acc_t;
 };
 
-template <typename T, const int BM, const int BN, const int group_size, const int bits>
+
+template <typename T, typename U, int values_per_thread, int bits>
+inline U load_vector(const device T *x, thread U *x_thread) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  U sum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < values_per_thread; i += 4) {
+      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      x_thread[i] = x[i];
+      x_thread[i+1] = x[i+1] / 4.0f;
+      x_thread[i+2] = x[i+2] / 16.0f;
+      x_thread[i+3] = x[i+3] / 64.0f;
+    }
+  }
+
+  else if (bits == 4) {
+    for (int i = 0; i < values_per_thread; i += 4) {
+      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      x_thread[i] = x[i];
+      x_thread[i+1] = x[i+1] / 16.0f;
+      x_thread[i+2] = x[i+2] / 256.0f;
+      x_thread[i+3] = x[i+3] / 4096.0f;
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < values_per_thread; i++) {
+      sum += x[i];
+      x_thread[i] = x[i];
+    }
+  }
+
+  return sum;
+}
+
+template <typename U, int values_per_thread, int bits>
+inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  U accum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      accum += (
+          x_thread[4*i] * (w[i] & 0x03) +
+          x_thread[4*i+1] * (w[i] & 0x0c) +
+          x_thread[4*i+2] * (w[i] & 0x30) +
+          x_thread[4*i+3] * (w[i] & 0xc0));
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      accum += (
+          x_thread[4*i] * (ws[i] & 0x000f) +
+          x_thread[4*i+1] * (ws[i] & 0x00f0) +
+          x_thread[4*i+2] * (ws[i] & 0x0f00) +
+          x_thread[4*i+3] * (ws[i] & 0xf000));
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < values_per_thread; i++) {
+      accum += x_thread[i] * w[i];
+    }
+  }
+
+  return scale * accum + sum * bias;
+}
+
+template <typename T, const int group_size, const int bits, const int packs_per_thread>
+[[kernel]] void qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[results_per_simdgroup] = {0};
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size / pack_factor;
+  const int in_vec_size_g = in_vec_size / group_size;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
+  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+  x += tid.z * in_vec_size + simd_lid * values_per_thread;
+  y += tid.z * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device T* sl = scales + row * in_vec_size_g;
+      const device T* bl = biases + row * in_vec_size_g;
+
+      U s = sl[0];
+      U b = bl[0];
+      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+    }
+
+    w += block_size / pack_factor;
+    scales += block_size / group_size;
+    biases += block_size / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
+
+template <typename T, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -33,91 +169,101 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
     const constant int& in_vec_size [[buffer(5)]],
     const constant int& out_vec_size [[buffer(6)]],
     uint3 tid [[threadgroup_position_in_grid]],
-    uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
 
-  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
-
-  constexpr int bitmask = (1 << bits) - 1;
-  constexpr int el_per_thread = 32 / bits;
-  constexpr int colgroup = BN * el_per_thread;
-  constexpr int groups_per_block = colgroup / group_size;
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int packs_per_thread = 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
 
-  typedef typename AccT<T>::acc_t U;
-  threadgroup U scales_block[BM * groups_per_block];
-  threadgroup U biases_block[BM * groups_per_block];
-  threadgroup U x_block[colgroup];
-
-  thread uint32_t w_local;
-  thread U result = 0;
-  thread U scale = 1;
-  thread U bias = 0;
-  thread U x_thread[el_per_thread];
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[results_per_simdgroup] = {0};
 
   // Adjust positions
-  const int in_vec_size_w = in_vec_size / el_per_thread;
+  const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  int out_row = tid.y * BM + simd_gid;
-  w += out_row * in_vec_size_w;
-  scales += out_row * in_vec_size_g;
-  biases += out_row * in_vec_size_g;
-  x += tid.z * in_vec_size;
-  y += tid.z * out_vec_size;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
+  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
 
   if (out_row >= out_vec_size) {
     return;
   }
 
-  // Loop over in_vec in blocks of colgroup
-  for (int i=0; i<in_vec_size; i+=colgroup) {
+  // In this case we need to properly guard all our reads because there isn't
+  // even 1 tile in the matrix
+  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
+    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.z * in_vec_size + simd_lid * values_per_thread;
+    y += tid.z * out_vec_size + out_row;
+
+    for (int k = 0; k < in_vec_size; k += block_size) {
+      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+      for (int row = 0; out_row + row < out_vec_size; row++) {
+        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+        const device T* sl = scales + row * in_vec_size_g;
+        const device T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+
+      w += block_size / pack_factor;
+      scales += block_size / group_size;
+      biases += block_size / group_size;
+      x += block_size;
+    }
-    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
-    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
-
-    // Load the matrix elements
-    w_local = w[i / el_per_thread + simd_lid];
-
-    // Do all the work.
-    #pragma clang loop unroll(full)
-    for (int k=0; k<el_per_thread; k++) {
-      result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
-      w_local >>= bits;
+    for (int row = 0; out_row + row < out_vec_size; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) {
+        y[row] = static_cast<T>(result[row]);
+      }
     }
   }
 
-  // Accumulate in the simdgroup
-  result = simd_sum(result);
+  // In this case the last tile is moved back to redo some output values
+  else {
+    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
+    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
+    x += tid.z * in_vec_size + simd_lid * values_per_thread;
+    y += tid.z * out_vec_size + used_out_row;
 
-  // Store the result
-  if (simd_lid == 0) {
-    y[out_row] = static_cast<T>(result);
+    for (int k = 0; k < in_vec_size; k += block_size) {
+      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+      for (int row = 0; row < results_per_simdgroup; row++) {
+        const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+        const device T* sl = scales + row * in_vec_size_g;
+        const device T* bl = biases + row * in_vec_size_g;
+
+        U s = sl[0];
+        U b = bl[0];
+        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      }
+
+      w += block_size / pack_factor;
+      scales += block_size / group_size;
+      biases += block_size / group_size;
+      x += block_size;
+    }
+
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      result[row] = simd_sum(result[row]);
+      if (simd_lid == 0) {
+        y[row] = static_cast<T>(result[row]);
+      }
+    }
   }
 }
 
@@ -532,9 +678,38 @@
-#define instantiate_qmv(name, itype, group_size, bits) \
-  template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
+#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
+  template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
+  [[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
+      const device uint32_t* w [[buffer(0)]], \
+      const device itype* scales [[buffer(1)]], \
+      const device itype* biases [[buffer(2)]], \
+      const device itype* x [[buffer(3)]], \
+      device itype* y [[buffer(4)]], \
+      const constant int& in_vec_size [[buffer(5)]], \
+      const constant int& out_vec_size [[buffer(6)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
+  instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
+  instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
+  instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
+
+instantiate_qmv_fast_types(128, 2, 1)
+instantiate_qmv_fast_types(128, 4, 2)
+instantiate_qmv_fast_types(128, 8, 2)
+instantiate_qmv_fast_types( 64, 2, 1)
+instantiate_qmv_fast_types( 64, 4, 2)
+instantiate_qmv_fast_types( 64, 8, 2)
+instantiate_qmv_fast_types( 32, 2, 1)
+instantiate_qmv_fast_types( 32, 4, 2)
+instantiate_qmv_fast_types( 32, 8, 2)
+
+#define instantiate_qmv(name, itype, group_size, bits) \
+  template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
+  [[kernel]] void qmv<itype, group_size, bits>( \
       const device uint32_t* w [[buffer(0)]], \
       const device itype* scales [[buffer(1)]], \
      const device itype* biases [[buffer(2)]], \
@@ -543,7 +718,6 @@
       const constant int& in_vec_size [[buffer(5)]], \
       const constant int& out_vec_size [[buffer(6)]], \
       uint3 tid [[threadgroup_position_in_grid]], \
-      uint lid [[thread_index_in_threadgroup]], \
       uint simd_gid [[simdgroup_index_in_threadgroup]], \
       uint simd_lid [[thread_index_in_simdgroup]]);
 
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ ... @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   int B = x.size() / D;
   int O = out.shape(-1);
   if (transpose_) {
+    // Route to the fast qmv kernel that has no bounds checking
+    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
+      std::ostringstream kname;
+      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_ << "_fast";
+
+      // Encode and dispatch kernel
+      auto compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 8;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, 2, 1);
+      MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+
+      set_array_buffer(compute_encoder, w, 0);
+      set_array_buffer(compute_encoder, scales, 1);
+      set_array_buffer(compute_encoder, biases, 2);
+      set_array_buffer(compute_encoder, x, 3);
+      set_array_buffer(compute_encoder, out, 4);
+      compute_encoder->setBytes(&D, sizeof(int), 5);
+      compute_encoder->setBytes(&O, sizeof(int), 6);
+
+      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    }
+
     // Route to the qmv kernel
-    if (B < 6) {
+    else if (B < 6) {
       std::ostringstream kname;
       kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
             << bits_;
@@ -52,9 +79,9 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       auto kernel = d.get_kernel(kname.str());
       compute_encoder->setComputePipelineState(kernel);
 
-      int bo = std::min(32, O);
+      int bo = 8;
       int bd = 32;
-      MTL::Size group_dims = MTL::Size(bd, bo, 1);
+      MTL::Size group_dims = MTL::Size(bd, 2, 1);
       MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
       set_array_buffer(compute_encoder, w, 0);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index e0e11f4c1..75159335f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3032,6 +3032,15 @@ array quantized_matmul(
   }
 
   auto dtype = result_type({x, scales, biases});
+  if (!is_floating_point(dtype) || is_complex(dtype)) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] Only real floating types are supported but "
+        << "the passed types where x.dtype() == " << x.dtype()
+        << ", scales.dtype() == " << scales.dtype()
+        << " and biases.dtype() == " << biases.dtype();
+    throw std::invalid_argument(msg.str());
+  }
+
   auto out = array(
       {x.shape(0), w_outer_dims},
       dtype,
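
Note on the decode trick used by the load_vector/qdot helpers in the patch above: for sub-byte weights the kernel never bit-shifts the packed values. load_vector pre-divides the activations by 4/16/64 (2-bit) or 16/256/4096 (4-bit) and keeps a running sum of the unscaled activations, so qdot can multiply by the masked-but-unshifted bit fields and fold the bias in as sum * bias. The following is a minimal host-side C++ sketch of that identity for the 2-bit case; it is illustration only (plain CPU code with invented example values, not part of the patch or of MLX):

// Illustration only: the 2-bit case of the load_vector/qdot trick, on the CPU.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // One packed byte holds four 2-bit quantized weights q0..q3 (q0 in the low bits).
  const uint8_t packed = 0xD8;  // 0b11011000 -> q0=0, q1=2, q2=1, q3=3
  const float scale = 0.5f, bias = -1.0f;
  const float x[4] = {0.25f, -1.5f, 2.0f, 0.75f};

  // Reference: unpack, dequantize, then take the dot product.
  float reference = 0.0f;
  for (int i = 0; i < 4; ++i) {
    const int q = (packed >> (2 * i)) & 0x3;
    reference += x[i] * (scale * q + bias);
  }

  // Kernel-style: pre-scale x and keep the sum of x (as load_vector does), then
  // multiply by the masked-but-unshifted bit fields and fold the bias back in
  // via the running sum (as qdot does).
  const float sum = x[0] + x[1] + x[2] + x[3];
  const float xt[4] = {x[0], x[1] / 4.0f, x[2] / 16.0f, x[3] / 64.0f};
  const float accum = xt[0] * (packed & 0x03) + xt[1] * (packed & 0x0c) +
                      xt[2] * (packed & 0x30) + xt[3] * (packed & 0xc0);
  const float kernel_style = scale * accum + sum * bias;

  assert(std::fabs(reference - kernel_style) < 1e-6f);
  std::printf("reference = %f, kernel-style = %f\n", reference, kernel_style);
  return 0;
}

On the dispatch side, each threadgroup runs num_simdgroups = 2 simdgroups and each simdgroup produces results_per_simdgroup = 4 output rows, i.e. 8 rows per threadgroup; that is why the fast path in quantized.cpp requires O % 8 == 0 and launches a grid with O / 8 threadgroups along y.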