From c52d1600f02fdb0ed76a2442aba7d870718bd65c Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Mon, 29 Jul 2024 15:11:38 -0700 Subject: [PATCH] Fused Affine Quantize/Dequantize ops (#1282) * Add fast affine dequantize * add full quantize kernel * fused kernel with scale/bias computation * fix docstring * fix no jit error * fix test * test fix * reduce fast api to only affine_quantize --- mlx/backend/metal/kernels/quantized.h | 165 ++++++++++++- mlx/backend/metal/kernels/quantized.metal | 278 ++++------------------ mlx/backend/metal/quantized.cpp | 87 ++++++- mlx/backend/no_metal/primitives.cpp | 1 + mlx/fast.cpp | 249 +++++++++++++++++++ mlx/fast.h | 22 ++ mlx/fast_primitives.h | 30 +++ mlx/ops.cpp | 156 +----------- python/src/fast.cpp | 44 ++++ python/tests/test_fast.py | 12 + python/tests/test_quantized.py | 11 +- 11 files changed, 655 insertions(+), 400 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 28a055576..f4d750e58 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl( template < typename T, - const int BM, - const int BK, - const int BN, const int group_size, const int bits, - const bool aligned_N> + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> METAL_FUNC void qmm_t_impl( const device T* x, const device uint32_t* w, @@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl( template < typename T, - const int BM, - const int BK, - const int BN, const int group_size, - const int bits> + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> METAL_FUNC void qmm_n_impl( const device T* x, const device uint32_t* w, @@ -1099,7 +1099,7 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; - qmm_t_impl( + qmm_t_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } @@ -1131,7 +1131,7 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; - qmm_n_impl( + qmm_n_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } @@ -1382,7 +1382,7 @@ template < s_strides, b_strides, tid); - qmm_t_impl( + qmm_t_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } @@ -1450,6 +1450,147 @@ template < s_strides, b_strides, tid); - qmm_n_impl( + qmm_n_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } + +template +[[kernel]] void affine_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device T* scales [[buffer(2)]], + device T* biases [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + constexpr T eps = T(1e-7); + constexpr int simd_size = 32; + constexpr int uint8_bits = 8; + constexpr T n_bins = (1 << bits) - 1; + constexpr int packs_per_int = uint8_bits / bits; + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 
1 : values_per_reduce / packs_per_int; + + static_assert( + group_size % simd_size == 0, + "Group size must be divisible by simd size."); + + int in_index = index * values_per_reduce; + int out_index = index * writes_per_pack; + + T w_thread[values_per_reduce]; + T w_min = Limits::max; + T w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + T val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + w_min = simd_min(w_min); + w_max = simd_max(w_max); + + T scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + T edge = side ? w_min : w_max; + T q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? scale : edge / q0; + T bias = at_zero ? T(0) : edge; + + // Write out the scales and biases + int gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = scale; + biases[gindex] = bias; + } + + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * (i % packs_per_int)); + } + + if (packs_per_int < values_per_reduce && + i % packs_per_int == packs_per_int - 1) { + out[out_index + i / packs_per_int] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 0; j < writes_per_reduce - 1; j++) { + uint8_t sval = simd_shuffle_down(val, j + 1); + output += sval << (bits * (values_per_reduce + j + i)); + } + } + } + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } +} + +template +[[kernel]] void affine_quantize_scales_biases( + const device T* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device uint8_t* out [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + constexpr T n_bins = (1 << bits) - 1; + + int in_index = index * packs_per_int; + int gindex = in_index / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + + uint8_t output = 0; +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output += val << (bits * i); + } + } + out[index] = output; +} + +template +[[kernel]] void affine_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device T* out [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + constexpr int uint8_bits = 8; + constexpr int packs_per_int = uint8_bits / bits; + + int oindex = index * packs_per_int; + int gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + uint val = w[index]; + +#pragma clang loop unroll(full) + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[oindex + i] = scale * d + bias; + } +} diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 0651db872..130cdda22 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ 
b/mlx/backend/metal/kernels/quantized.metal @@ -5,241 +5,67 @@ #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/quantized.h" - -#define instantiate_qmv_fast(itype, group_size, bits) \ +#define instantiate_quantized(name, type, group_size, bits) \ instantiate_kernel( \ - "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ - qmv_fast, \ - itype, \ + #name "_" #type "_gs_" #group_size "_b_" #bits, \ + name, \ + type, \ group_size, \ bits) -#define instantiate_qmv_fast_types(group_size, bits) \ - instantiate_qmv_fast(float, group_size, bits) \ - instantiate_qmv_fast(float16_t, group_size, bits) \ - instantiate_qmv_fast(bfloat16_t, group_size, bits) +#define instantiate_quantized_types(name, group_size, bits) \ + instantiate_quantized(name, float, group_size, bits) \ + instantiate_quantized(name, float16_t, group_size, bits) \ + instantiate_quantized(name, bfloat16_t, group_size, bits) -instantiate_qmv_fast_types(128, 2) -instantiate_qmv_fast_types(128, 4) -instantiate_qmv_fast_types(128, 8) -instantiate_qmv_fast_types( 64, 2) -instantiate_qmv_fast_types( 64, 4) -instantiate_qmv_fast_types( 64, 8) -instantiate_qmv_fast_types( 32, 2) -instantiate_qmv_fast_types( 32, 4) -instantiate_qmv_fast_types( 32, 8) +#define instantiate_quantized_groups(name, bits) \ + instantiate_quantized_types(name, 128, bits) \ + instantiate_quantized_types(name, 64, bits) \ + instantiate_quantized_types(name, 32, bits) -#define instantiate_qmv(itype, group_size, bits) \ - instantiate_kernel( \ - "qmv_" #itype "_gs_" #group_size "_b_" #bits, \ - qmv, \ - itype, \ - group_size, \ - bits) +#define instantiate_quantized_all(name) \ + instantiate_quantized_groups(name, 2) \ + instantiate_quantized_groups(name, 4) \ + instantiate_quantized_groups(name, 8) -#define instantiate_qmv_types(group_size, bits) \ - instantiate_qmv(float, group_size, bits) \ - instantiate_qmv(float16_t, group_size, bits) \ - instantiate_qmv(bfloat16_t, group_size, bits) +instantiate_quantized_all(qmv_fast) +instantiate_quantized_all(qmv) +instantiate_quantized_all(qvm) +instantiate_quantized_all(qmm_n) +instantiate_quantized_all(bs_qmv_fast) +instantiate_quantized_all(bs_qmv) +instantiate_quantized_all(bs_qvm) +instantiate_quantized_all(bs_qmm_n) +instantiate_quantized_all(affine_quantize) +instantiate_quantized_all(affine_quantize_scales_biases) +instantiate_quantized_all(affine_dequantize) -instantiate_qmv_types(128, 2) -instantiate_qmv_types(128, 4) -instantiate_qmv_types(128, 8) -instantiate_qmv_types( 64, 2) -instantiate_qmv_types( 64, 4) -instantiate_qmv_types( 64, 8) -instantiate_qmv_types( 32, 2) -instantiate_qmv_types( 32, 4) -instantiate_qmv_types( 32, 8) +#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned) -#define instantiate_qvm(itype, group_size, bits) \ - instantiate_kernel( \ - "qvm_" #itype "_gs_" #group_size "_b_" #bits, \ - qvm, \ - itype, \ - group_size, \ - bits) +#define instantiate_quantized_types_aligned(name, group_size, bits) \ + instantiate_quantized_aligned(name, float, group_size, bits, true) \ + instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \ + instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \ + instantiate_quantized_aligned(name, float, group_size, bits, false) \ + instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \ + 
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false) -#define instantiate_qvm_types(group_size, bits) \ - instantiate_qvm(float, group_size, bits) \ - instantiate_qvm(float16_t, group_size, bits) \ - instantiate_qvm(bfloat16_t, group_size, bits) +#define instantiate_quantized_groups_aligned(name, bits) \ + instantiate_quantized_types_aligned(name, 128, bits) \ + instantiate_quantized_types_aligned(name, 64, bits) \ + instantiate_quantized_types_aligned(name, 32, bits) -instantiate_qvm_types(128, 2) -instantiate_qvm_types(128, 4) -instantiate_qvm_types(128, 8) -instantiate_qvm_types( 64, 2) -instantiate_qvm_types( 64, 4) -instantiate_qvm_types( 64, 8) -instantiate_qvm_types( 32, 2) -instantiate_qvm_types( 32, 4) -instantiate_qvm_types( 32, 8) +#define instantiate_quantized_all_aligned(name) \ + instantiate_quantized_groups_aligned(name, 2) \ + instantiate_quantized_groups_aligned(name, 4) \ + instantiate_quantized_groups_aligned(name, 8) \ -#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \ - instantiate_kernel( \ - "qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \ - qmm_t, \ - itype, \ - group_size, \ - bits, \ - aligned_N) - -#define instantiate_qmm_t_types(group_size, bits) \ - instantiate_qmm_t(float, group_size, bits, false) \ - instantiate_qmm_t(float16_t, group_size, bits, false) \ - instantiate_qmm_t(bfloat16_t, group_size, bits, false) \ - instantiate_qmm_t(float, group_size, bits, true) \ - instantiate_qmm_t(float16_t, group_size, bits, true) \ - instantiate_qmm_t(bfloat16_t, group_size, bits, true) - -instantiate_qmm_t_types(128, 2) -instantiate_qmm_t_types(128, 4) -instantiate_qmm_t_types(128, 8) -instantiate_qmm_t_types( 64, 2) -instantiate_qmm_t_types( 64, 4) -instantiate_qmm_t_types( 64, 8) -instantiate_qmm_t_types( 32, 2) -instantiate_qmm_t_types( 32, 4) -instantiate_qmm_t_types( 32, 8) - -#define instantiate_qmm_n(itype, group_size, bits) \ - instantiate_kernel( \ - "qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \ - qmm_n, \ - itype, \ - group_size, \ - bits) - -#define instantiate_qmm_n_types(group_size, bits) \ - instantiate_qmm_n(float, group_size, bits) \ - instantiate_qmm_n(float16_t, group_size, bits) \ - instantiate_qmm_n(bfloat16_t, group_size, bits) - -instantiate_qmm_n_types(128, 2) -instantiate_qmm_n_types(128, 4) -instantiate_qmm_n_types(128, 8) -instantiate_qmm_n_types( 64, 2) -instantiate_qmm_n_types( 64, 4) -instantiate_qmm_n_types( 64, 8) -instantiate_qmm_n_types( 32, 2) -instantiate_qmm_n_types( 32, 4) -instantiate_qmm_n_types( 32, 8) - -#define instantiate_bs_qmv_fast(itype, group_size, bits) \ - instantiate_kernel( \ - "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ - bs_qmv_fast, \ - itype, \ - group_size, \ - bits) - -#define instantiate_bs_qmv_fast_types(group_size, bits) \ - instantiate_bs_qmv_fast(float, group_size, bits) \ - instantiate_bs_qmv_fast(float16_t, group_size, bits) \ - instantiate_bs_qmv_fast(bfloat16_t, group_size, bits) - -instantiate_bs_qmv_fast_types(128, 2) -instantiate_bs_qmv_fast_types(128, 4) -instantiate_bs_qmv_fast_types(128, 8) -instantiate_bs_qmv_fast_types( 64, 2) -instantiate_bs_qmv_fast_types( 64, 4) -instantiate_bs_qmv_fast_types( 64, 8) -instantiate_bs_qmv_fast_types( 32, 2) -instantiate_bs_qmv_fast_types( 32, 4) -instantiate_bs_qmv_fast_types( 32, 8) - -#define instantiate_bs_qmv(itype, group_size, bits) \ - instantiate_kernel( \ - "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \ - bs_qmv, \ - itype, \ - group_size, \ - bits) - -#define 
instantiate_bs_qmv_types(group_size, bits) \ - instantiate_bs_qmv(float, group_size, bits) \ - instantiate_bs_qmv(float16_t, group_size, bits) \ - instantiate_bs_qmv(bfloat16_t, group_size, bits) - -instantiate_bs_qmv_types(128, 2) -instantiate_bs_qmv_types(128, 4) -instantiate_bs_qmv_types(128, 8) -instantiate_bs_qmv_types( 64, 2) -instantiate_bs_qmv_types( 64, 4) -instantiate_bs_qmv_types( 64, 8) -instantiate_bs_qmv_types( 32, 2) -instantiate_bs_qmv_types( 32, 4) -instantiate_bs_qmv_types( 32, 8) - -#define instantiate_bs_qvm(itype, group_size, bits) \ - instantiate_kernel( \ - "bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \ - bs_qvm, \ - itype, \ - group_size, \ - bits) - -#define instantiate_bs_qvm_types(group_size, bits) \ - instantiate_bs_qvm(float, group_size, bits) \ - instantiate_bs_qvm(float16_t, group_size, bits) \ - instantiate_bs_qvm(bfloat16_t, group_size, bits) - -instantiate_bs_qvm_types(128, 2) -instantiate_bs_qvm_types(128, 4) -instantiate_bs_qvm_types(128, 8) -instantiate_bs_qvm_types( 64, 2) -instantiate_bs_qvm_types( 64, 4) -instantiate_bs_qvm_types( 64, 8) -instantiate_bs_qvm_types( 32, 2) -instantiate_bs_qvm_types( 32, 4) -instantiate_bs_qvm_types( 32, 8) - -#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \ - instantiate_kernel( \ - "bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \ - bs_qmm_t, \ - itype, \ - group_size, \ - bits, \ - aligned_N) - -#define instantiate_bs_qmm_t_types(group_size, bits) \ - instantiate_bs_qmm_t(float, group_size, bits, false) \ - instantiate_bs_qmm_t(float16_t, group_size, bits, false) \ - instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \ - instantiate_bs_qmm_t(float, group_size, bits, true) \ - instantiate_bs_qmm_t(float16_t, group_size, bits, true) \ - instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true) - -instantiate_bs_qmm_t_types(128, 2) -instantiate_bs_qmm_t_types(128, 4) -instantiate_bs_qmm_t_types(128, 8) -instantiate_bs_qmm_t_types( 64, 2) -instantiate_bs_qmm_t_types( 64, 4) -instantiate_bs_qmm_t_types( 64, 8) -instantiate_bs_qmm_t_types( 32, 2) -instantiate_bs_qmm_t_types( 32, 4) -instantiate_bs_qmm_t_types( 32, 8) - -#define instantiate_bs_qmm_n(itype, group_size, bits) \ - instantiate_kernel( \ - "bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \ - bs_qmm_n, \ - itype, \ - group_size, \ - bits) - -#define instantiate_bs_qmm_n_types(group_size, bits) \ - instantiate_bs_qmm_n(float, group_size, bits) \ - instantiate_bs_qmm_n(float16_t, group_size, bits) \ - instantiate_bs_qmm_n(bfloat16_t, group_size, bits) - -instantiate_bs_qmm_n_types(128, 2) -instantiate_bs_qmm_n_types(128, 4) -instantiate_bs_qmm_n_types(128, 8) -instantiate_bs_qmm_n_types( 64, 2) -instantiate_bs_qmm_n_types( 64, 4) -instantiate_bs_qmm_n_types( 64, 8) -instantiate_bs_qmm_n_types( 32, 2) -instantiate_bs_qmm_n_types( 32, 4) -instantiate_bs_qmm_n_types( 32, 8) // clang-format on +instantiate_quantized_all_aligned(qmm_t) +instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index f0a64d5e6..17dbd02d1 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -7,6 +7,7 @@ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast_primitives.h" #include "mlx/primitives.h" namespace mlx::core { @@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { if (B < 6 && O % 8 == 0 && D % 512 == 
0 && D >= 512) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); - kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_ - << "_fast"; + kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_" + << bits_; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); - kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_ << "_fast"; + kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_" + << bits_; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -513,4 +514,82 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { } } +void fast::AffineQuantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + bool compute_scale_bias = inputs.size() == 1; + + auto& w_pre = inputs[0]; + auto& out = outputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + + std::vector copies; + auto ensure_row_contiguous = [&copies, &s](const array& arr) { + if (arr.flags().row_contiguous) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + return arr_copy; + } + }; + auto w = ensure_row_contiguous(w_pre); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_input_array(w, 0); + if (!compute_scale_bias) { + auto& scales_pre = inputs[1]; + auto& biases_pre = inputs[2]; + auto scales = ensure_row_contiguous(scales_pre); + auto biases = ensure_row_contiguous(biases_pre); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_output_array(out, 3); + } else { + auto& scales = outputs[1]; + auto& biases = outputs[2]; + scales.set_data(allocator::malloc_or_wait(scales.nbytes())); + biases.set_data(allocator::malloc_or_wait(biases.nbytes())); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_output_array(scales, 2); + compute_encoder.set_output_array(biases, 3); + } + + std::ostringstream kname; + auto type_string = dequantize_ ? get_type_string(out.dtype()) + : get_type_string(w_pre.dtype()); + auto kernel_func = "affine_quantize_scales_biases"; + if (dequantize_) { + kernel_func = "affine_dequantize"; + } else if (compute_scale_bias) { + kernel_func = "affine_quantize"; + } + kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_" + << bits_; + auto template_def = get_template_definition( + kname.str(), kernel_func, type_string, group_size_, bits_); + auto kernel = get_quantized_kernel(d, kname.str(), template_def); + compute_encoder->setComputePipelineState(kernel); + + // Treat uint32 as uint8 in kernel + constexpr int uint8_per_uint32 = 4; + constexpr int simd_size = 32; + int packs_per_int = 8 / bits_; + int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int; + size_t nthreads = + dequantize_ ? 
w.size() * uint8_per_uint32 : w.size() / per_thread; + + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + auto group_dims = MTL::Size(thread_group_size, 1, 1); + auto grid_dims = MTL::Size(nthreads, 1, 1); + compute_encoder.dispatchThreads(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 1410c92e8..8ee6bb1e2 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -118,6 +118,7 @@ NO_GPU_MULTI(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_MULTI(RoPE) NO_GPU(ScaledDotProductAttention) +NO_GPU_MULTI(AffineQuantize) } // namespace fast } // namespace mlx::core diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 135c79490..4a4819b5a 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -610,4 +610,253 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_; } +array pack_and_quantize( + array& packed_w, + const array& scales, + const array& biases, + int group_size, + int bits, + const Stream& s) { + int el_per_int = 32 / bits; + array zero(0, packed_w.dtype()); + array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 + array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); + packed_w = astype( + clip( + round(divide(subtract(packed_w, biases, s), scales, s), s), + zero, + n_bins), + uint32); + packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); + packed_w = sum( + multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + return packed_w; +} + +std::tuple +affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { + auto s = to_stream(s_); + + if (group_size != 32 && group_size != 64 && group_size != 128) { + std::ostringstream msg; + msg << "[quantize] The requested group size " << group_size + << " is not supported. The supported group sizes are 64 and 128."; + throw std::invalid_argument(msg.str()); + } + + if (bits != 2 && bits != 4 && bits != 8) { + std::ostringstream msg; + msg << "[quantize] The requested number of bits " << bits + << " is not supported. The supported bits are 2, 4 and 8."; + throw std::invalid_argument(msg.str()); + } + + if (w.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + if ((w.shape(-1) % group_size) != 0) { + std::ostringstream msg; + msg << "[quantize] The last dimension of the matrix needs to be divisible by " + << "the quantization group size " << group_size + << ". However the provided " << " matrix has shape " << w.shape(); + throw std::invalid_argument(msg.str()); + } + + int el_per_int = 32 / bits; + + if (w.shape(-1) < 32 * el_per_int) { + std::ostringstream msg; + msg << "[quantize] The feature dimension (2nd dimension of the matrix) is " + << "too small for quantization. We support >=512 for 2 bits, " + << ">= 256 for 4 bits and >= 128 for 8 bits. 
The provided matrix has " + << "shape " << w.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto fallback = [group_size, bits, el_per_int, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + auto wshape = w.shape(); + wshape.back() = -1; + + array zero(0, w.dtype()); + array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1 + array eps(1e-7, w.dtype()); + + array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); + + array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array mask = greater(abs(w_min, s), abs(w_max, s), s); + array scales = + maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + array biases = where(equal(q0, zero, s), zero, edge); + + packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); + return { + reshape(packed_w, wshape, s), + reshape(scales, wshape, s), + reshape(biases, wshape, s), + }; + }; + + std::vector outputs; + if (s.device == Device::gpu) { + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) / el_per_int; + auto sshape = w.shape(); + sshape.back() = w.shape(-1) / group_size; + outputs = array::make_arrays( + {wq_shape, sshape, sshape}, + {uint32, w.dtype(), w.dtype()}, + std::make_shared(s, fallback, group_size, bits, false), + {w}); + } else { + outputs = fallback({w}); + } + return {outputs[0], outputs[1], outputs[2]}; +} + +array affine_quantize( + const array& w, + const array& scales, + const array& biases, + int group_size, + int bits, + StreamOrDevice s_) { + auto s = to_stream(s_); + + int el_per_int = 32 / bits; + auto fallback = [group_size, bits, el_per_int, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + auto scales = expand_dims(inputs[1], -1, s); + auto biases = expand_dims(inputs[2], -1, s); + + auto wshape = w.shape(); + wshape.back() = -1; + + array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); + packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); + return {reshape(packed_w, wshape, s)}; + }; + + if (s.device == Device::gpu) { + auto out_shape = w.shape(); + out_shape.back() = w.shape(-1) / el_per_int; + return array( + out_shape, + uint32, + std::make_shared(s, fallback, group_size, bits, false), + {w, scales, biases}); + } + return fallback({w, scales, biases})[0]; +} + +array affine_dequantize( + const array& w, + const array& scales, + const array& biases, + int group_size, + int bits, + StreamOrDevice s_) { + if (bits <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for bits: " << bits; + throw std::invalid_argument(msg.str()); + } + if (group_size <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for group_size: " << group_size; + throw std::invalid_argument(msg.str()); + } + if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + auto bshape = biases.shape(); + wshape.back() = -1; + sshape.back() = -1; + bshape.back() = -1; + + if (wshape != sshape || 
wshape != bshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales and biases does not match the matrix"); + } + + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + // Packing into uint32 + int el_per_int = 32 / bits; + + if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales and biases does not match the matrix " + << "given the quantization parameters. Provided matrix of shape " + << w.shape() << " and scales/biases of shape " << scales.shape() + << " with group_size=" << group_size << " and bits=" << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + + auto fallback = + [&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + auto& scales = inputs[1]; + auto& biases = inputs[2]; + std::vector parts; + for (int start = 0; start < 32; start += bits) { + int shift_left = 32 - (start + bits); + int shift_right = shift_left + start; + + parts.push_back(expand_dims( + right_shift( + left_shift(w, array(32 - (start + bits), uint32), s), + array(32 - bits, uint32), + s), + -1, + s)); + } + array w_full = concatenate(parts, -1, s); + + // Dequantize + wshape.push_back(group_size); + w_full = reshape(w_full, wshape, s); + w_full = multiply(w_full, expand_dims(scales, -1, s), s); + w_full = add(w_full, expand_dims(biases, -1, s), s); + w_full = reshape(w_full, sshape, s); + + return {w_full}; + }; + + if (s.device == Device::gpu) { + auto out_shape = w.shape(); + out_shape.back() = w.shape(-1) * el_per_int; + return array( + out_shape, + scales.dtype(), + std::make_shared(s, fallback, group_size, bits, true), + {w, scales, biases}); + } + return fallback({w, scales, biases})[0]; +} + } // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h index 4d73de581..4c63df8fe 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -39,4 +39,26 @@ array scaled_dot_product_attention( const std::optional& mask = std::nullopt, StreamOrDevice s = {}); +std::tuple affine_quantize( + const array& w, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + +array affine_quantize( + const array& w, + const array& scales, + const array& biases, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + +array affine_dequantize( + const array& w, + const array& scales, + const array& biases, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + } // namespace mlx::core::fast diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 001bf67c9..1883f5789 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -212,4 +212,34 @@ class ScaledDotProductAttention : public Custom { bool needs_mask_; }; +class AffineQuantize : public Custom { + public: + explicit AffineQuantize( + Stream stream, + std::function(std::vector)> fallback, + int group_size, + int bits, + bool dequantize) + : Custom(stream, fallback), + group_size_(group_size), + bits_(bits), + dequantize_(dequantize) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(AffineQuantize); + + private: + std::function(std::vector)> fallback_; + int group_size_; + int bits_; + bool dequantize_; +}; + } // namespace mlx::core::fast diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 
a57a0df32..d50014ab8 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -6,6 +6,7 @@ #include #include +#include "mlx/fast.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" @@ -3356,89 +3357,7 @@ std::tuple quantize( int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { - if (group_size != 32 && group_size != 64 && group_size != 128) { - std::ostringstream msg; - msg << "[quantize] The requested group size " << group_size - << " is not supported. The supported group sizes are 64 and 128."; - throw std::invalid_argument(msg.str()); - } - - if (bits != 2 && bits != 4 && bits != 8) { - std::ostringstream msg; - msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 4 and 8."; - throw std::invalid_argument(msg.str()); - } - - if (w.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - if ((w.shape(-1) % group_size) != 0) { - std::ostringstream msg; - msg << "[quantize] The last dimension of the matrix needs to be divisible by " - << "the quantization group size " << group_size - << ". However the provided " << " matrix has shape " << w.shape(); - throw std::invalid_argument(msg.str()); - } - - // Compute some constants used for the quantization - array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1 - array eps(1e-7, w.dtype()); - array zero(0, w.dtype()); - int el_per_int = 32 / bits; - array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); - shifts = reshape(shifts, {1, 1, -1}, s); - - // Check that the w matrix will fill up a whole SIMD. - // This is an implementation detail which should be removed in the future but - // at least we bail out early which will result in a nice readable error. - // - // Hopefully nobody is quantizing matrices that small anyway. - if (w.shape(-1) < 32 * el_per_int) { - std::ostringstream msg; - msg << "[quantize] The feature dimension (2nd dimension of the matrix) is " - << "too small for quantization. We support >=512 for 2 bits, " - << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has " - << "shape " << w.shape() << "."; - throw std::invalid_argument(msg.str()); - } - - // Prepare the shape for the outputs. 
- auto wshape = w.shape(); - wshape.back() = -1; - - // Compute scales and biases - array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); - array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge); - - // Quantize and pack w - packed_w = astype( - clip( - round(divide(subtract(packed_w, biases, s), scales, s), s), - zero, - n_bins), - uint32); - packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); - packed_w = sum( - multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); - - return std::make_tuple( - reshape(packed_w, wshape, s), - reshape(scales, wshape, s), - reshape(biases, wshape, s)); + return fast::affine_quantize(w, group_size, bits); } array dequantize( @@ -3448,76 +3367,7 @@ array dequantize( int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { - if (bits <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for bits: " << bits; - throw std::invalid_argument(msg.str()); - } - if (group_size <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for group_size: " << group_size; - throw std::invalid_argument(msg.str()); - } - if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - auto wshape = w.shape(); - auto sshape = scales.shape(); - auto bshape = biases.shape(); - wshape.back() = -1; - sshape.back() = -1; - bshape.back() = -1; - - if (wshape != sshape || wshape != bshape) { - throw std::invalid_argument( - "[dequantize] Shape of scales and biases does not match the matrix"); - } - - if (w.dtype() != uint32) { - throw std::invalid_argument( - "[dequantize] The matrix should be given as a uint32"); - } - - // Compute some constants for the dequantization - int el_per_int = 32 / bits; - - if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) { - std::ostringstream msg; - msg << "[dequantize] Shape of scales and biases does not match the matrix " - << "given the quantization parameters. 
Provided matrix of shape " - << w.shape() << " and scales/biases of shape " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits << "."; - throw std::invalid_argument(msg.str()); - } - - // Extract the pieces from the passed quantized matrix - std::vector parts; - for (int start = 0; start < 32; start += bits) { - int shift_left = 32 - (start + bits); - int shift_right = shift_left + start; - - parts.push_back(expand_dims( - right_shift( - left_shift(w, array(32 - (start + bits), uint32), s), - array(32 - bits, uint32), - s), - -1, - s)); - } - array w_full = concatenate(parts, -1, s); - - // Dequantize - wshape.push_back(group_size); - w_full = reshape(w_full, wshape, s); - w_full = multiply(w_full, expand_dims(scales, -1, s), s); - w_full = add(w_full, expand_dims(biases, -1, s), s); - w_full = reshape(w_full, sshape, s); - - return w_full; + return fast::affine_dequantize(w, scales, biases, group_size, bits, s); } array gather_qmm( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 451937b21..be14c1d85 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "mlx/fast.h" @@ -138,4 +139,47 @@ void init_fast(nb::module_& parent_module) { Returns: array: The output array. )pbdoc"); + + m.def( + "affine_quantize", + nb::overload_cast< + const array&, + const array&, + const array&, + int, + int, + StreamOrDevice>(&fast::affine_quantize), + "w"_a, + "scales"_a, + "biases"_a, + "group_size"_a = 64, + "bits"_a = 4, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Quantize the matrix ``w`` using the provided ``scales`` and + ``biases`` and the ``group_size`` and ``bits`` configuration. + + Formally, given the notation in :func:`quantize`, we compute + :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and + :math:`\beta` as follows + + .. math:: + + w_i = s (\hat{w_i} + \beta) + + Args: + w (array): Matrix to be quantize + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + group_size (int, optional): The size of the group in ``w`` that shares a + scale and bias. (default: ``64``) + bits (int, optional): The number of bits occupied by each element in + ``w``. 
(default: ``4``) + + Returns: + array: The quantized version of ``w`` + )pbdoc"); } diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b554be55b..8e889c8b4 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -439,6 +439,18 @@ class TestFast(mlx_tests.MLXTestCase): )(x) self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) + def test_affine_quantize(self): + mx.random.seed(7) + x = mx.random.uniform(shape=(4, 1024)) + for bits in (2, 4, 8): + for group_size in (32, 64, 128): + with self.subTest(bits=bits, group_size=group_size): + w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size) + w_p = mx.fast.affine_quantize( + x, scales, biases, bits=bits, group_size=group_size + ) + self.assertTrue(mx.allclose(w, w_p)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 92ad3d3e7..47c924c59 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -12,11 +12,12 @@ class TestQuantized(mlx_tests.MLXTestCase): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: for b in [2, 4, 8]: - w_q, scales, biases = mx.quantize(w, gs, b) - w_hat = mx.dequantize(w_q, scales, biases, gs, b) - errors = (w - w_hat).abs().reshape(*scales.shape, -1) - eps = 1e-6 - self.assertTrue((errors <= (scales[..., None] + eps).abs()).all()) + with self.subTest(gs=gs, b=b): + w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) + w_hat = mx.dequantize(w_q, scales, biases, gs, b) + errors = (w - w_hat).abs().reshape(*scales.shape, -1) + eps = 1e-6 + self.assertTrue((errors <= (scales[..., None] + eps).abs()).all()) # test quantize/dequantize 0s a = mx.zeros((256, 512))
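
A minimal usage sketch of the fused path added in this patch, assuming a build with these changes applied; it uses only APIs that appear in the diff (mx.quantize, mx.fast.affine_quantize, mx.dequantize) and mirrors python/tests/test_fast.py and python/tests/test_quantized.py:

import mlx.core as mx

w = mx.random.uniform(shape=(4, 1024))
group_size, bits = 64, 4  # group_size may be 32, 64, or 128; bits may be 2, 4, or 8

# mx.quantize now routes through fast::affine_quantize and returns the packed
# weights plus one scale and one bias per group of `group_size` elements.
w_q, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)

# The fast API re-packs w with precomputed scales/biases: each element is
# stored as round((w - bias) / scale) clipped to [0, 2**bits - 1], with
# 32 // bits values packed into every uint32 word.
w_q2 = mx.fast.affine_quantize(w, scales, biases, group_size=group_size, bits=bits)
assert mx.allclose(w_q, w_q2)

# Dequantization computes scale * w_hat + bias per group, so the round-trip
# error is bounded by roughly one quantization step (|scale|) per group.
w_hat = mx.dequantize(w_q, scales, biases, group_size=group_size, bits=bits)
errors = (w - w_hat).abs().reshape(*scales.shape, -1)
assert (errors <= (scales[..., None] + 1e-6).abs()).all()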