From 8da1c64fe95c4a9395696bba08906a452e50af08 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 20 Aug 2025 17:18:47 -0700 Subject: [PATCH] cpu mxfp4 --- mlx/backend/cpu/quantized.cpp | 447 ++++++++++++++++++---- mlx/backend/metal/kernels/fp4_quantized.h | 24 +- mlx/fast.cpp | 44 --- 3 files changed, 379 insertions(+), 136 deletions(-) diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index f9cfb347d..75ff35a2e 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -13,6 +13,35 @@ namespace mlx::core { namespace { +const static float MXFP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +template +static inline T dequantize_scale(uint8_t s) { + using FOrI = union { + bfloat16_t f; + uint16_t i; + }; + FOrI out; + out.i = (s == 0 ? 0x40 : (static_cast(s) << 7)); + return static_cast(out.f); +} + inline constexpr short get_pack_factor(int bits, int wsize = 8) { return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); } @@ -407,50 +436,230 @@ void _qmm_dispatch( } } -// template -// void _qmm_mxfp4_dispatch_typed( -// array& out, -// const array& x, -// const array& w, -// const array& scales, -// bool transposed_w) { -// int K = x.shape(-1); -// int M = x.ndim() > 1 ? x.shape(-2) : 1; -// int N = out.shape(-1); -// int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; -// int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; -// int batch_size = x.size() / (K * M); -// -// auto out_ptr = out.data(); -// auto x_ptr = x.data(); -// auto w_ptr = w.data(); -// auto scales_ptr = scales.data(); -// for (int i = 0; i < batch_size; i++) { -// _qmm_mxfp4_dispatch_typed( -// out_ptr + i * M * N, -// x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), -// w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), -// scales_ptr + elem_to_loc(i * g_els, scales.shape(), -// scales.strides()), M, N, K, transposed_w); -// } -// } -// -// -// void _qmm_mxfp4_dispatch( -// array& out, -// const array& x, -// const array& w, -// const array& scales, -// bool transposed_w) { -// switch (x.dtype()) { -// case bfloat16: -// _qmm_mxfp4_dispatch_typed(out, x, w, scales, transposed_w); -// break; -// default: -// throw std::invalid_argument( -// "[quantized_matmul] only bfloat is supported for mxfp4"); -// } -// } +template +void mxfp4_qmm( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = get_pack_factor(4, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(4); + constexpr int packs_in_group = group_size / pack_factor; + + for (int m = 0; m < M; m++) { + const uint8_t* w_local = (const uint8_t*)w; + const uint8_t* scales_local = scales; + + std::fill(result, result + N, 0); + + for (int k = 0; k < K; k++) { + T* result_local = result; + T xi = *x++; + + for (int n = 0; n < N; n += group_size) { + T scale = dequantize_scale(*scales_local++); + for (int ng = 0; ng < packs_in_group; ng++) { + uint8_t wi = *w_local++; +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + (*result_local++) += + xi * scale * static_cast(MXFP4_LUT[wi & 0xf]); + wi >>= 4; + } + } + } + } + + result += N; + } +} + +template +void mxfp4_qmm_t( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = 
get_pack_factor(4, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(4); + constexpr int packs_in_group = group_size / pack_factor; + + for (int m = 0; m < M; m++) { + const uint8_t* w_local = (const uint8_t*)w; + const uint8_t* scales_local = scales; + + for (int n = 0; n < N; n++) { + const T* x_local = x; + T sum = 0; + for (int k = 0; k < K; k += group_size) { + T scale = dequantize_scale(*scales_local++); + + T gsum = 0; + for (int kw = 0; kw < packs_in_group; kw++) { + uint8_t wi = *w_local++; +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + gsum += (*x_local++) * static_cast(MXFP4_LUT[wi & 0xf]); + wi >>= 4; + } + } + sum += scale * gsum; + } + *result = sum; + result++; + } + + x += K; + } +} + +template +simd::Simd mxfp4_extract_bits_simd(const uint32_t* w) { + if constexpr (S == 8) { + constexpr std::array shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}}; + auto shifts(*(simd::Simd*)&shifts_); + auto wi = simd::Simd(*w); + wi = wi >> shifts; + wi = wi & 0xf; + simd::Simd w_out; + for (int i = 0; i < S; ++i) { + w_out[i] = MXFP4_LUT[wi[i]]; + } + return w_out; + } else { + // Appease compiler.. but should never get here + throw std::runtime_error("Unsupported combination for simd qmm."); + } +} + +template +void mxfp4_qmm_t_simd( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = 32 / 4; + constexpr int packs_in_group = group_size / pack_factor; + constexpr int S = simd::max_size; + static_assert( + S % pack_factor == 0, "SIMD size must be divisible by pack factor"); + constexpr int packs_per_simd = S / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const uint8_t* scales_local = scales; + + for (int n = 0; n < N; n++) { + simd::Simd acc(0); + auto x_local = x; + for (int k = 0; k < K; k += group_size) { + T scale = dequantize_scale(*scales_local++); + + simd::Simd g_acc(0); + for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) { + // Extract bits + auto wf = mxfp4_extract_bits_simd(w_local); + w_local += packs_per_simd; + simd::Simd x_simd = simd::load(x_local); + g_acc = g_acc + x_simd * wf; + x_local += S; + } + acc = acc + scale * g_acc; + } + + *result = T(simd::sum(acc)); + result++; + } + x += K; + } +} + +template +void mxfp4_qmm_dispatch_transpose( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K, + bool transposed_w) { + if (transposed_w) { + // the simd size must be a multiple of the number of elements per word + if constexpr (simd::max_size % 8 == 0) { + mxfp4_qmm_t_simd(result, x, w, scales, M, N, K); + } else { + mxfp4_qmm_t(result, x, w, scales, M, N, K); + } + } else { + mxfp4_qmm(result, x, w, scales, M, N, K); + } +} + +template +void mxfp4_qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + bool transposed_w) { + int K = x.shape(-1); + int M = x.ndim() > 1 ? x.shape(-2) : 1; + int N = out.shape(-1); + int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; + int g_els = w.ndim() > 2 ? 
scales.shape(-1) * scales.shape(-2) : 0; + int batch_size = x.size() / (K * M); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + for (int i = 0; i < batch_size; i++) { + mxfp4_qmm_dispatch_transpose( + out_ptr + i * M * N, + x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), + scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()), + M, + N, + K, + transposed_w); + } +} + +void mxfp4_qmm_dispatch( + array& out, + const array& x, + const array& w, + const array& scales, + bool transposed_w) { + switch (x.dtype()) { + case bfloat16: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + case float16: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + case float32: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} template void _bs_qmm_dispatch_typed( @@ -558,6 +767,74 @@ void _bs_qmm_dispatch( } } +template +void mxfp4_bs_qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + const array& lhs_indices, + const array& rhs_indices, + bool transposed_w) { + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + + int w_els = w.shape(-1) * w.shape(-2); + int g_els = scales.shape(-1) * scales.shape(-2); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + auto lhs_indices_ptr = lhs_indices.data(); + auto rhs_indices_ptr = rhs_indices.data(); + + for (int i = 0; i < lhs_indices.size(); i++) { + int x_idx = lhs_indices_ptr[elem_to_loc( + i, lhs_indices.shape(), lhs_indices.strides())]; + int w_idx = rhs_indices_ptr[elem_to_loc( + i, rhs_indices.shape(), rhs_indices.strides())]; + mxfp4_qmm_dispatch_transpose( + out_ptr + i * M * N, + x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), + scales_ptr + + elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()), + M, + N, + K, + transposed_w); + } +} + +void mxfp4_bs_qmm_dispatch( + array& out, + const array& x, + const array& w, + const array& scales, + const array& lhs_indices, + const array& rhs_indices, + bool transposed_w) { + switch (x.dtype()) { + case float32: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + case float16: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + case bfloat16: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} + } // namespace void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { @@ -604,15 +881,13 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); }); } else { - // encoder.dispatch([out = array::unsafe_weak_copy(out), - // x = array::unsafe_weak_copy(x), - // w = array::unsafe_weak_copy(w), - // scales = array::unsafe_weak_copy(scales), - // group_size_ = group_size_, - // bits_ = bits_, - // transpose_ = transpose_]() mutable { - // _qmm_mxfp4_dispatch(out, x, w, scales, transpose_); - // }); + encoder.dispatch([out = 
array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + transpose_ = transpose_]() mutable { + mxfp4_qmm_dispatch(out, x, w, scales, transpose_); + }); } } @@ -622,9 +897,8 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; + auto& lhs_indices = inputs[inputs.size() - 2]; + auto& rhs_indices = inputs[inputs.size() - 1]; std::vector temps; auto ensure_row_contiguous_last_dims = [s = stream(), @@ -643,7 +917,6 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { auto x = ensure_row_contiguous_last_dims(x_pre); auto w = ensure_row_contiguous_last_dims(w_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); out.set_data(allocator::malloc(out.nbytes())); @@ -652,32 +925,46 @@ void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); - encoder.set_input_array(biases); encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); encoder.set_output_array(out); - encoder.dispatch([out = array::unsafe_weak_copy(out), - x = array::unsafe_weak_copy(x), - w = array::unsafe_weak_copy(w), - scales = array::unsafe_weak_copy(scales), - biases = array::unsafe_weak_copy(biases), - lhs_indices = array::unsafe_weak_copy(lhs_indices), - rhs_indices = array::unsafe_weak_copy(rhs_indices), - group_size_ = group_size_, - bits_ = bits_, - transpose_ = transpose_]() mutable { - _bs_qmm_dispatch( - out, - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - group_size_, - bits_, - transpose_); - }); + if (mode_ == "affine") { + auto biases = ensure_row_contiguous_last_dims(inputs[3]); + encoder.set_input_array(biases); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _bs_qmm_dispatch( + out, + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + group_size_, + bits_, + transpose_); + }); + } else { + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + transpose_ = transpose_]() mutable { + mxfp4_bs_qmm_dispatch( + out, x, w, scales, lhs_indices, rhs_indices, transpose_); + }); + } } template diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h index 1c261b3c0..40c5fa187 100644 --- a/mlx/backend/metal/kernels/fp4_quantized.h +++ b/mlx/backend/metal/kernels/fp4_quantized.h @@ -95,10 +95,10 @@ inline U qdot( const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { accum += - (x_thread[4 * i] * lut[ws[i] & 0x000f] + - x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0x000f] + - x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0x000f] + - x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 
0x000f]); + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); } return scale * accum; } @@ -115,10 +115,10 @@ inline U qdot_safe( const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { accum += - (x_thread[4 * i] * lut[ws[i] & 0x000f] + - x_thread[4 * i + 1] * lut[(ws[i] & 0x00f0) >> 4] + - x_thread[4 * i + 2] * lut[(ws[i] & 0x0f00) >> 8] + - x_thread[4 * i + 3] * lut[(ws[i] & 0xf000) >> 12]); + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); } return scale * accum; } @@ -131,8 +131,8 @@ inline void qouter( thread U* result, const threadgroup U* lut) { for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * scale * lut[w[i] & 0x0f]; - result[2 * i + 1] += x * scale * lut[(w[i] & 0xf0) >> 4]; + result[2 * i] += x * scale * lut[w[i] & 0xf]; + result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf]; } } @@ -143,8 +143,8 @@ inline void dequantize( threadgroup U* w_local, const threadgroup U* lut) { for (int i = 0; i < (N / 2); i++) { - w_local[2 * i] = scale * lut[w[i] & 0x0f]; - w_local[2 * i + 1] = scale * lut[(w[i] & 0xf0) >> 4]; + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; } } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 2917b1584..befc9f80c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -762,50 +762,6 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_; } -array pack_and_quantize( - array& packed_w, - const array& scales, - const array& biases, - int bits, - const Stream& s) { - int el_per_int = 32 / bits; - array zero(0, packed_w.dtype()); - array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 - packed_w = astype( - clip( - round(divide(subtract(packed_w, biases, s), scales, s), s), - zero, - n_bins, - s), - uint32, - s); - if (is_power_of_2(bits)) { - array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); - packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); - packed_w = sum( - multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); - } else { - // This is slow but we have fast GPU/CPU versions of this function so we - // shouldn't be here often. - packed_w = expand_dims(packed_w, /* axis= */ -1, s); - packed_w = bitwise_and( - right_shift(packed_w, arange(bits, uint32, s), s), - array({1}, uint32), - s); - auto new_shape = packed_w.shape(); - new_shape[new_shape.size() - 2] = -1; - new_shape.back() = 32; - packed_w = reshape(packed_w, new_shape, s); - array shifts = arange(32, uint32, s); - packed_w = - sum(left_shift(packed_w, shifts, s), - /* axis= */ -1, - /* keepdims= */ false, - s); - } - return packed_w; -} - bool Quantize::is_equivalent(const Primitive& other) const { const Quantize& p_other = static_cast(other); return (
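
For reference, the kernels added above all reduce to the same two primitives: decoding an E8M0 scale byte into a power of two, and mapping each packed 4-bit code through the E2M1 value table (MXFP4_LUT). The following is a minimal standalone sketch of that per-group dequantization, assuming the LUT values and low-nibble-first packing shown in the patch; decode_e8m0 and dequantize_group are illustrative helper names, not part of the MLX API.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Same 16-entry E2M1 value table as MXFP4_LUT in the patch.
static const float kMxfp4Lut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// E8M0 shared scale: an unsigned exponent with bias 127, i.e. 2^(s - 127).
// This mirrors dequantize_scale() above, which builds the same power of two
// directly in bfloat16 bits (including the subnormal case for s == 0).
static inline float decode_e8m0(uint8_t s) {
  return std::ldexp(1.0f, static_cast<int>(s) - 127);
}

// One MXFP4 group: 32 4-bit codes packed two per byte (low nibble first)
// plus one shared scale byte.
static void dequantize_group(
    const uint8_t packed[16],
    uint8_t scale,
    float out[32]) {
  float s = decode_e8m0(scale);
  for (int i = 0; i < 16; ++i) {
    out[2 * i + 0] = s * kMxfp4Lut[packed[i] & 0xf];
    out[2 * i + 1] = s * kMxfp4Lut[(packed[i] >> 4) & 0xf];
  }
}

int main() {
  // Hypothetical group: code 0x7 (= 6.0) in the low nibble and 0x9 (= -0.5)
  // in the high nibble of every byte, with scale byte 128 -> 2^(128-127) = 2.
  uint8_t packed[16];
  for (int i = 0; i < 16; ++i) {
    packed[i] = 0x97;
  }
  float out[32];
  dequantize_group(packed, 128, out);
  std::printf("%g %g\n", out[0], out[1]); // expected: 12 -1
  return 0;
}

The matmul kernels in the patch never materialize a dequantized buffer like out[]: mxfp4_qmm and mxfp4_qmm_t fuse the LUT lookup into the dot product and apply the group scale once per 32-element group, which is why the scale is factored out of the inner accumulation (gsum) in the transposed path.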