From 70560b6bd5324efbfd9f97c9076c884d21ffac34 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 28 Aug 2025 06:45:26 -0700
Subject: [PATCH] Add mode parameter for quantization (#2499)

* add mode parameter for quantization

* mxfp4 quantize/dequantize + start of optional biases

* mxfp4 works

* speedup

* cpu mxfp4

* fix

* fix test tol

* fix

* refactor

* add quant mode enum
---
 docs/src/dev/custom_metal_kernels.rst | 2 +-
 mlx/backend/cpu/quantized.cpp | 459 ++++-
 mlx/backend/cuda/quantized/quantized.cpp | 4 +-
 mlx/backend/metal/kernels/CMakeLists.txt | 3 +-
 mlx/backend/metal/kernels/fp4_quantized.h | 1791 +++++++++++++++++
 mlx/backend/metal/kernels/fp4_quantized.metal | 127 ++
 mlx/backend/metal/kernels/quantized.h | 112 +-
 mlx/backend/metal/kernels/quantized.metal | 45 +-
 mlx/backend/metal/kernels/quantized_utils.h | 90 +
 mlx/backend/metal/quantized.cpp | 405 ++--
 mlx/backend/no_cpu/primitives.cpp | 2 +-
 mlx/backend/no_gpu/primitives.cpp | 2 +-
 mlx/export.cpp | 2 +-
 mlx/fast.cpp | 257 +--
 mlx/fast.h | 14 -
 mlx/fast_primitives.h | 11 +-
 mlx/ops.cpp | 520 ++++-
 mlx/ops.h | 12 +-
 mlx/primitives.cpp | 38 +-
 mlx/primitives.h | 15 +-
 python/mlx/nn/layers/embedding.py | 4 +-
 python/mlx/nn/layers/linear.py | 4 +-
 python/mlx/nn/layers/quantized.py | 81 +-
 python/src/ops.cpp | 121 +-
 python/tests/cuda_skip.py | 2 +
 python/tests/test_nn.py | 6 +
 python/tests/test_quantized.py | 258 ++-
 tests/ops_tests.cpp | 5 +-
 28 files changed, 3635 insertions(+), 757 deletions(-)
 create mode 100644 mlx/backend/metal/kernels/fp4_quantized.h
 create mode 100644 mlx/backend/metal/kernels/fp4_quantized.metal
 create mode 100644 mlx/backend/metal/kernels/quantized_utils.h

diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst
index 1febe960a..4c4ce65ae 100644
--- a/docs/src/dev/custom_metal_kernels.rst
+++ b/docs/src/dev/custom_metal_kernels.rst
@@ -127,7 +127,7 @@ relying on a copy from ``ensure_row_contiguous``:
         name="myexp_strided",
         input_names=["inp"],
         output_names=["out"],
-        source=source
+        source=source,
         ensure_row_contiguous=False,
     )
 
diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index 1c02c4e61..a475131f7 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -1,7 +1,5 @@
 // Copyright © 2023 Apple Inc.
 
-#include <cassert>
-
 #include "mlx/backend/cpu/copy.h"
 #include "mlx/backend/cpu/encoder.h"
 #include "mlx/backend/cpu/simd/simd.h"
@@ -13,6 +11,35 @@ namespace mlx::core {
 
 namespace {
 
+const static float MXFP4_LUT[16] = {
+    +0.0f,
+    +0.5f,
+    +1.0f,
+    +1.5f,
+    +2.0f,
+    +3.0f,
+    +4.0f,
+    +6.0f,
+    -0.0f,
+    -0.5f,
+    -1.0f,
+    -1.5f,
+    -2.0f,
+    -3.0f,
+    -4.0f,
+    -6.0f};
+
+template <typename T>
+static inline T dequantize_scale(uint8_t s) {
+  using FOrI = union {
+    bfloat16_t f;
+    uint16_t i;
+  };
+  FOrI out;
+  out.i = (s == 0 ? 0x40 : (static_cast<uint16_t>(s) << 7));
+  return static_cast<T>(out.f);
+}
+
 inline constexpr short get_pack_factor(int bits, int wsize = 8) {
   return (bits == 3 || bits == 5) ? 8 : (bits == 6 ?
4 : wsize / bits); } @@ -407,6 +434,231 @@ void _qmm_dispatch( } } +template +void mxfp4_qmm( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = get_pack_factor(4, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(4); + constexpr int packs_in_group = group_size / pack_factor; + + for (int m = 0; m < M; m++) { + const uint8_t* w_local = (const uint8_t*)w; + const uint8_t* scales_local = scales; + + std::fill(result, result + N, 0); + + for (int k = 0; k < K; k++) { + T* result_local = result; + T xi = *x++; + + for (int n = 0; n < N; n += group_size) { + T scale = dequantize_scale(*scales_local++); + for (int ng = 0; ng < packs_in_group; ng++) { + uint8_t wi = *w_local++; +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + (*result_local++) += + xi * scale * static_cast(MXFP4_LUT[wi & 0xf]); + wi >>= 4; + } + } + } + } + + result += N; + } +} + +template +void mxfp4_qmm_t( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = get_pack_factor(4, 8); + constexpr int bytes_per_pack = get_bytes_per_pack(4); + constexpr int packs_in_group = group_size / pack_factor; + + for (int m = 0; m < M; m++) { + const uint8_t* w_local = (const uint8_t*)w; + const uint8_t* scales_local = scales; + + for (int n = 0; n < N; n++) { + const T* x_local = x; + T sum = 0; + for (int k = 0; k < K; k += group_size) { + T scale = dequantize_scale(*scales_local++); + + T gsum = 0; + for (int kw = 0; kw < packs_in_group; kw++) { + uint8_t wi = *w_local++; +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + gsum += (*x_local++) * static_cast(MXFP4_LUT[wi & 0xf]); + wi >>= 4; + } + } + sum += scale * gsum; + } + *result = sum; + result++; + } + + x += K; + } +} + +template +simd::Simd mxfp4_extract_bits_simd(const uint32_t* w) { + if constexpr (S == 8) { + constexpr std::array shifts_ = {{0, 4, 8, 12, 16, 20, 24, 28}}; + auto shifts(*(simd::Simd*)&shifts_); + auto wi = simd::Simd(*w); + wi = wi >> shifts; + wi = wi & 0xf; + simd::Simd w_out; + for (int i = 0; i < S; ++i) { + w_out[i] = MXFP4_LUT[wi[i]]; + } + return w_out; + } else { + // Appease compiler.. 
but should never get here + throw std::runtime_error("Unsupported combination for simd qmm."); + } +} + +template +void mxfp4_qmm_t_simd( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K) { + constexpr int group_size = 32; + constexpr int pack_factor = 32 / 4; + constexpr int packs_in_group = group_size / pack_factor; + constexpr int S = simd::max_size; + static_assert( + S % pack_factor == 0, "SIMD size must be divisible by pack factor"); + constexpr int packs_per_simd = S / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const uint8_t* scales_local = scales; + + for (int n = 0; n < N; n++) { + simd::Simd acc(0); + auto x_local = x; + for (int k = 0; k < K; k += group_size) { + T scale = dequantize_scale(*scales_local++); + + simd::Simd g_acc(0); + for (int kw = 0; kw < packs_in_group; kw += packs_per_simd) { + // Extract bits + auto wf = mxfp4_extract_bits_simd(w_local); + w_local += packs_per_simd; + simd::Simd x_simd = simd::load(x_local); + g_acc = g_acc + x_simd * wf; + x_local += S; + } + acc = acc + scale * g_acc; + } + + *result = T(simd::sum(acc)); + result++; + } + x += K; + } +} + +template +void mxfp4_qmm_dispatch_transpose( + T* result, + const T* x, + const uint32_t* w, + const uint8_t* scales, + int M, + int N, + int K, + bool transposed_w) { + if (transposed_w) { + // the simd size must be a multiple of the number of elements per word + if constexpr (simd::max_size % 8 == 0) { + mxfp4_qmm_t_simd(result, x, w, scales, M, N, K); + } else { + mxfp4_qmm_t(result, x, w, scales, M, N, K); + } + } else { + mxfp4_qmm(result, x, w, scales, M, N, K); + } +} + +template +void mxfp4_qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + bool transposed_w) { + int K = x.shape(-1); + int M = x.ndim() > 1 ? x.shape(-2) : 1; + int N = out.shape(-1); + int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; + int g_els = w.ndim() > 2 ? 
scales.shape(-1) * scales.shape(-2) : 0; + int batch_size = x.size() / (K * M); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + for (int i = 0; i < batch_size; i++) { + mxfp4_qmm_dispatch_transpose( + out_ptr + i * M * N, + x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), + scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()), + M, + N, + K, + transposed_w); + } +} + +void mxfp4_qmm_dispatch( + array& out, + const array& x, + const array& w, + const array& scales, + bool transposed_w) { + switch (x.dtype()) { + case bfloat16: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + case float16: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + case float32: + mxfp4_qmm_dispatch_typed(out, x, w, scales, transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} + template void _bs_qmm_dispatch_typed( array& out, @@ -513,115 +765,198 @@ void _bs_qmm_dispatch( } } +template +void mxfp4_bs_qmm_dispatch_typed( + array& out, + const array& x, + const array& w, + const array& scales, + const array& lhs_indices, + const array& rhs_indices, + bool transposed_w) { + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + + int w_els = w.shape(-1) * w.shape(-2); + int g_els = scales.shape(-1) * scales.shape(-2); + + auto out_ptr = out.data(); + auto x_ptr = x.data(); + auto w_ptr = w.data(); + auto scales_ptr = scales.data(); + auto lhs_indices_ptr = lhs_indices.data(); + auto rhs_indices_ptr = rhs_indices.data(); + + for (int i = 0; i < lhs_indices.size(); i++) { + int x_idx = lhs_indices_ptr[elem_to_loc( + i, lhs_indices.shape(), lhs_indices.strides())]; + int w_idx = rhs_indices_ptr[elem_to_loc( + i, rhs_indices.shape(), rhs_indices.strides())]; + mxfp4_qmm_dispatch_transpose( + out_ptr + i * M * N, + x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), + w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), + scales_ptr + + elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()), + M, + N, + K, + transposed_w); + } +} + +void mxfp4_bs_qmm_dispatch( + array& out, + const array& x, + const array& w, + const array& scales, + const array& lhs_indices, + const array& rhs_indices, + bool transposed_w) { + switch (x.dtype()) { + case float32: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + case float16: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + case bfloat16: + mxfp4_bs_qmm_dispatch_typed( + out, x, w, scales, lhs_indices, rhs_indices, transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} + } // namespace void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 4); - auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; - std::vector temps; - auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) { + auto& encoder = cpu::get_command_encoder(stream()); + auto ensure_row_contiguous = [s = stream(), &encoder](const array& arr) { if (arr.flags().row_contiguous) { return arr; } else { - temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy_cpu(arr, temps.back(), CopyType::General, s); - return 
temps.back(); + auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); + copy_cpu(arr, arr_cpy, CopyType::General, s); + encoder.add_temporary(arr_cpy); + return arr_cpy; } }; auto x = ensure_row_contiguous(x_pre); auto w = ensure_row_contiguous(w_pre); auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); out.set_data(allocator::malloc(out.nbytes())); - auto& encoder = cpu::get_command_encoder(stream()); - encoder.add_temporaries(std::move(temps)); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); - encoder.set_input_array(biases); encoder.set_output_array(out); - encoder.dispatch([out = array::unsafe_weak_copy(out), - x = array::unsafe_weak_copy(x), - w = array::unsafe_weak_copy(w), - scales = array::unsafe_weak_copy(scales), - biases = array::unsafe_weak_copy(biases), - group_size_ = group_size_, - bits_ = bits_, - transpose_ = transpose_]() mutable { - _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); - }); + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous(inputs[3]); + encoder.set_input_array(biases); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); + }); + } else { + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + transpose_ = transpose_]() mutable { + mxfp4_qmm_dispatch(out, x, w, scales, transpose_); + }); + } } void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 6); - auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; + auto& lhs_indices = inputs[inputs.size() - 2]; + auto& rhs_indices = inputs[inputs.size() - 1]; - std::vector temps; + auto& encoder = cpu::get_command_encoder(stream()); auto ensure_row_contiguous_last_dims = [s = stream(), - &temps](const array& arr) { + &encoder](const array& arr) { auto stride_0 = arr.strides()[arr.ndim() - 2]; auto stride_1 = arr.strides()[arr.ndim() - 1]; if (stride_0 == arr.shape(-1) && stride_1 == 1) { return arr; } else { - temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {})); - copy_cpu(arr, temps.back(), CopyType::General, s); - return temps.back(); + auto arr_cpy = array(arr.shape(), arr.dtype(), nullptr, {}); + copy_cpu(arr, arr_cpy, CopyType::General, s); + encoder.add_temporary(arr_cpy); + return arr_cpy; } }; auto x = ensure_row_contiguous_last_dims(x_pre); auto w = ensure_row_contiguous_last_dims(w_pre); auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); out.set_data(allocator::malloc(out.nbytes())); - auto& encoder = cpu::get_command_encoder(stream()); - encoder.add_temporaries(std::move(temps)); encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); - encoder.set_input_array(biases); encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); encoder.set_output_array(out); - encoder.dispatch([out = array::unsafe_weak_copy(out), - x = array::unsafe_weak_copy(x), 
- w = array::unsafe_weak_copy(w), - scales = array::unsafe_weak_copy(scales), - biases = array::unsafe_weak_copy(biases), - lhs_indices = array::unsafe_weak_copy(lhs_indices), - rhs_indices = array::unsafe_weak_copy(rhs_indices), - group_size_ = group_size_, - bits_ = bits_, - transpose_ = transpose_]() mutable { - _bs_qmm_dispatch( - out, - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - group_size_, - bits_, - transpose_); - }); + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous_last_dims(inputs[3]); + encoder.set_input_array(biases); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _bs_qmm_dispatch( + out, + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + group_size_, + bits_, + transpose_); + }); + } else { + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + transpose_ = transpose_]() mutable { + mxfp4_bs_qmm_dispatch( + out, x, w, scales, lhs_indices, rhs_indices, transpose_); + }); + } } template @@ -705,7 +1040,7 @@ void dispatch_quantize( w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); } -void fast::AffineQuantize::eval_cpu( +void fast::Quantize::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto ensure_row_contiguous = [s = stream()](const array& arr) { @@ -764,7 +1099,7 @@ void fast::AffineQuantize::eval_cpu( } } else { throw std::runtime_error( - "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); + "[fast::Quantize::eval_cpu] Only supports floating point inputs"); } }); } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 008001c50..71c687d85 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -46,10 +46,10 @@ inline array ensure_row_contiguous_matrix( } // namespace -void fast::AffineQuantize::eval_gpu( +void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { - nvtx3::scoped_range r("AffineQuantize::eval_gpu"); + nvtx3::scoped_range r("Quantize::eval_gpu"); auto& s = stream(); auto& d = cu::device(s.device); auto& enc = d.get_command_encoder(s); diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4069d8c21..70faa1d24 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -108,7 +108,8 @@ if(NOT MLX_METAL_JIT) reduction/reduce_all.h reduction/reduce_col.h reduction/reduce_row.h) - build_kernel(quantized quantized.h ${STEEL_HEADERS}) + build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS}) + build_kernel(fp4_quantized fp4_quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h new file mode 100644 index 000000000..b5a8918e4 --- /dev/null +++ 
b/mlx/backend/metal/kernels/fp4_quantized.h @@ -0,0 +1,1791 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + using FOrI = union { + bfloat16_t f; + uint16_t i; + }; + FOrI out; + out.i = (s == 0 ? 0x40 : (static_cast(s) << 7)); + return static_cast(out.f); +} + +template +inline void load_vector(const device T* x, thread U* x_thread) { + for (int i = 0; i < values_per_thread; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } +} + +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + +constexpr constant static float MXFP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +template +void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { + if (simd_gid == 0 && simd_lid < 16) { + lut[simd_lid] = static_cast(MXFP4_LUT[simd_lid]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut) { + U accum = 0; + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut, + int N) { + U accum = 0; + + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} + +template +inline void qouter( + const thread uint8_t* w, + U x, + U scale, + thread U* result, + const threadgroup U* lut) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * lut[w[i] & 0xf]; + result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + typename S> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, 
+ "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device S* scales; + threadgroup T* lut; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device S* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + lut(lut_) { + load_mxfp4_lut(lut, simd_group_id, simd_lane_id); + } + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales++; + } + } else { + scales += group_stride; + } + } +}; + +template +METAL_FUNC void mxfp4_qmv_quad_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 8; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + 
thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device S* sl = scales + row * in_vec_size_g * quads_per_simd; + + U s = dequantize_scale(sl[0]); + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, lut); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_fast_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& 
out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + S s = sl[0]; + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += 
block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += + qdot_safe(wl, x_thread, s, lut, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void mxfp4_qvm_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const bool aligned_N, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void 
mxfp4_qmm_t_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + S>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_n_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid 
[[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + S>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& 
x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void mxfp4_qmv_quad( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_quad_impl( + w, + scales, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid, + simd_gid, + simd_lid, + lut); +} + +template +[[kernel]] void mxfp4_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& 
x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qvm_split_k( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be 
smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, + scales, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid, + lut); +} + +template < + typename T, + const int group_size, + const bool aligned_N, + const bool batched, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const bool batched, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid 
[[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const bool aligned_N, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + 
const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + int group_size, + typename S, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void mxfp4_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device S* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T lut[16]; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? 
BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + S>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? 
K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp4_quantized.metal new file mode 100644 index 000000000..c982b4bf4 --- /dev/null +++ b/mlx/backend/metal/kernels/fp4_quantized.metal @@ -0,0 +1,127 @@ +// Copyright © 2025 Apple Inc. 
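The mxfp4_gather_qmm_rhs kernel above walks indices[y_row + n] to split each BM-row tile into runs of rows that select the same expert, performs one block GEMM per run, and writes each run back with store_result_slice over rows [offset, offset_next). A host-side sketch of that run detection, using a hypothetical index_runs helper that is not part of this patch:

#include <cstdint>
#include <utility>
#include <vector>

// Illustration only: split [0, rows) into half-open runs of equal expert
// indices, mirroring the kernel's offset / offset_next bookkeeping.
std::vector<std::pair<int, int>> index_runs(const uint32_t* indices, int rows) {
  std::vector<std::pair<int, int>> runs;
  int start = 0;
  for (int n = 1; n <= rows; ++n) {
    if (n == rows || indices[n] != indices[start]) {
      runs.emplace_back(start, n); // one small GEMM per run in the kernel
      start = n;
    }
  }
  return runs;
}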
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/fp4_quantized.h" + +#define instantiate_quantized(name, type) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4", \ + name, \ + type, \ + 32, \ + uint8_t) + +#define instantiate_quantized_batched(name, type, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_batch_" #batched, \ + name, \ + type, \ + 32, \ + batched, \ + uint8_t) + +#define instantiate_quantized_aligned(name, type, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_alN_" #aligned, \ + name, \ + type, \ + 32, \ + aligned, \ + uint8_t) + +#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \ + name, \ + type, \ + 32, \ + aligned, \ + batched, \ + uint8_t) + +#define instantiate_quantized_quad(name, type, D, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \ + name, \ + type, \ + 32, \ + D, \ + batched, \ + uint8_t) + +#define instantiate_quantized_split_k(name, type, split_k) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_spk_" #split_k, \ + name, \ + type, \ + 32, \ + split_k, \ + uint8_t) + +#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + 32, \ + uint8_t, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + +#define instantiate_quantized_batched_wrap(name, type) \ + instantiate_quantized_batched(name, type, 1) \ + instantiate_quantized_batched(name, type, 0) + +#define instantiate_quantized_all_batched(type) \ + instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \ + instantiate_quantized_batched_wrap(mxfp4_qmv, type) \ + instantiate_quantized_batched_wrap(mxfp4_qvm, type) \ + instantiate_quantized_batched_wrap(mxfp4_qmm_n, type) + +#define instantiate_quantized_all_single(type) \ + instantiate_quantized(mxfp4_gather_qmv_fast, type) \ + instantiate_quantized(mxfp4_gather_qmv, type) \ + instantiate_quantized(mxfp4_gather_qvm, type) \ + instantiate_quantized(mxfp4_gather_qmm_n, type) + +#define instantiate_quantized_all_aligned(type) \ + instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \ + instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0) + +#define instantiate_quantized_all_quad(type) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0) + +#define instantiate_quantized_all_splitk(type) \ + instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \ + instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32) + +#define instantiate_quantized_all_rhs(type) \ + instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 
32, 1, 2, false) + +#define instantiate_quantized_types(type) \ + instantiate_quantized_all_batched(type) \ + instantiate_quantized_all_quad(type) \ + instantiate_quantized_all_splitk(type) \ + instantiate_quantized_all_single(type) \ + instantiate_quantized_all_aligned(type) \ + instantiate_quantized_all_rhs(type) + +instantiate_quantized_types(float) +instantiate_quantized_types(bfloat16_t) +instantiate_quantized_types(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 0a40cec00..bf639814b 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets( } template -[[kernel]] void qmv_quad( +[[kernel]] void affine_qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1486,7 +1486,7 @@ template } template -[[kernel]] void qmv_fast( +[[kernel]] void affine_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1538,7 +1538,7 @@ template } template -[[kernel]] void qmv( +[[kernel]] void affine_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1590,7 +1590,7 @@ template } template -[[kernel]] void qvm( +[[kernel]] void affine_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1642,7 +1642,7 @@ template } template -[[kernel]] void qvm_split_k( +[[kernel]] void affine_qvm_split_k( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1706,7 +1706,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_t( +[[kernel]] void affine_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1764,7 +1764,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_n( +[[kernel]] void affine_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1817,7 +1817,7 @@ template < } template -[[kernel]] void gather_qmv_fast( +[[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1879,7 +1879,7 @@ template } template -[[kernel]] void gather_qmv( +[[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1941,7 +1941,7 @@ template } template -[[kernel]] void gather_qvm( +[[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2010,7 +2010,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_t( +[[kernel]] void affine_gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2077,7 +2077,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_n( +[[kernel]] void affine_gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2138,92 +2138,6 
@@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template -METAL_FUNC void gemm_loop_aligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template < - bool rows_aligned, - bool cols_aligned, - bool transpose, - typename T, - typename mma_t, - typename loader_a_t, - typename loader_b_t> -METAL_FUNC void gemm_loop_unaligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations, - const short tgp_bm, - const short tgp_bn, - const short tgp_bk) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - if (rows_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(short2(tgp_bk, tgp_bm)); - } - if (cols_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe( - transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template -METAL_FUNC void gemm_loop_finalize( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const short2 tile_a, - const short2 tile_b) { - loader_a.load_safe(tile_a); - loader_b.load_safe(tile_b); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); -} - template < typename T, int group_size, @@ -2234,7 +2148,7 @@ template < int WM, int WN, bool transpose> -[[kernel]] void gather_qmm_rhs( +[[kernel]] void affine_gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index de83cb657..f734b9bce 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -3,6 +3,7 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" #include "mlx/backend/metal/kernels/quantized.h" #define instantiate_quantized(name, type, group_size, bits) \ @@ -79,40 +80,40 @@ instantiate_quantized_batched(name, type, group_size, bits, 0) #define instantiate_quantized_all_batched(type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits) + instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \ + 
instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits) #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ - instantiate_quantized(gather_qmv_fast, type, group_size, bits) \ - instantiate_quantized(gather_qmv, type, group_size, bits) \ - instantiate_quantized(gather_qvm, type, group_size, bits) \ - instantiate_quantized(gather_qmm_n, type, group_size, bits) + instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \ + instantiate_quantized(affine_gather_qmv, type, group_size, bits) \ + instantiate_quantized(affine_gather_qvm, type, group_size, bits) \ + instantiate_quantized(affine_gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ - instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \ - instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0) + instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0) #define instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0) + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0) #define instantiate_quantized_all_splitk(type, group_size, bits) \ - instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ - instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) + instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \ + instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32) #define instantiate_quantized_all_rhs(type, group_size, bits) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) #define instantiate_quantized_funcs(type, group_size, bits) \ 
instantiate_quantized_all_single(type, group_size, bits) \ diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h new file mode 100644 index 000000000..38253f8fe --- /dev/null +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -0,0 +1,90 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 999825043..f8bc0342e 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1,7 +1,5 @@ // Copyright © 2023-2024 Apple Inc. 
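The host-side changes to this file replace fixed buffer indices with a running counter so the biases buffer is bound only when it is present (affine mode); mxfp4 dispatches bind just w, scales, and x, and every later argument shifts down by one slot. A minimal sketch of the binding pattern, with a stand-in Encoder type rather than the real command encoder API:

#include <optional>
#include <string>

// Stand-in for the Metal command encoder; illustration only, not the real API.
struct Encoder {
  void set_input_array(const std::string& /*name*/, int /*index*/) {}
};

void bind_quantized_args(Encoder& enc, const std::optional<std::string>& biases) {
  int c = 0;
  enc.set_input_array("w", c++);
  enc.set_input_array("scales", c++);
  if (biases) {
    enc.set_input_array(*biases, c++); // bound only in affine mode
  }
  enc.set_input_array("x", c++);
  // Later arguments (output, K, N, strides, ...) keep consecutive slots
  // whichever branch was taken, matching the per-mode kernel signatures.
}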
-#include - #include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/gpu/copy.h" @@ -99,7 +97,7 @@ inline int add_strides_and_shapes( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, int offset) { if (skip) { return 0; @@ -109,16 +107,18 @@ inline int add_strides_and_shapes( int x_batch_ndims = x.ndim() - 2; int w_batch_ndims = w.ndim() - 2; - compute_encoder.set_bytes(x_batch_ndims, offset); - compute_encoder.set_vector_bytes(x.shape(), offset + 1); - compute_encoder.set_vector_bytes(x.strides(), offset + 2); - compute_encoder.set_bytes(w_batch_ndims, offset + 3); - compute_encoder.set_vector_bytes(w.shape(), offset + 4); - compute_encoder.set_vector_bytes(w.strides(), offset + 5); - compute_encoder.set_vector_bytes(scales.strides(), offset + 6); - compute_encoder.set_vector_bytes(biases.strides(), offset + 7); + compute_encoder.set_bytes(x_batch_ndims, offset++); + compute_encoder.set_vector_bytes(x.shape(), offset++); + compute_encoder.set_vector_bytes(x.strides(), offset++); + compute_encoder.set_bytes(w_batch_ndims, offset++); + compute_encoder.set_vector_bytes(w.shape(), offset++); + compute_encoder.set_vector_bytes(w.strides(), offset++); + compute_encoder.set_vector_bytes(scales.strides(), offset++); + if (biases) { + compute_encoder.set_vector_bytes(biases->strides(), offset++); + } - return 8; + return offset; } inline int add_gather_strides_and_shapes( @@ -130,12 +130,12 @@ inline int add_gather_strides_and_shapes( lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); int ndims = shape.size(); - compute_encoder.set_bytes(ndims, offset); - compute_encoder.set_vector_bytes(shape, offset + 1); - compute_encoder.set_vector_bytes(strides[0], offset + 2); - compute_encoder.set_vector_bytes(strides[1], offset + 3); + compute_encoder.set_bytes(ndims, offset++); + compute_encoder.set_vector_bytes(shape, offset++); + compute_encoder.set_vector_bytes(strides[0], offset++); + compute_encoder.set_vector_bytes(strides[1], offset++); - return 4; + return offset; } } // namespace @@ -144,7 +144,7 @@ void qmv_quad( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -152,7 +152,8 @@ void qmv_quad( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; constexpr int quads_per_simd = 8; @@ -165,9 +166,10 @@ void qmv_quad( std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); + concatenate( kname, - "qmv_quad_", + mode + "_qmv_quad_", type_string, "_gs_", group_size, @@ -177,20 +179,23 @@ void qmv_quad( K, B > 1 ? 
"_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, "qmv_quad", type_string, group_size, bits, K, B > 1); + kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -199,7 +204,7 @@ void qmv( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -207,7 +212,8 @@ void qmv( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 8; @@ -219,9 +225,10 @@ void qmv( kname.reserve(64); std::string type_string = get_type_string(x.dtype()); bool fast = N % bn == 0 && K % 512 == 0; + concatenate( kname, - fast ? "qmv_fast_" : "qmv_", + mode + (fast ? "_qmv_fast_" : "_qmv_"), type_string, "_gs_", group_size, @@ -229,20 +236,28 @@ void qmv( bits, B > 1 ? "_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1); + kname, + mode + (fast ? "_qmv_fast" : "_qmv"), + type_string, + group_size, + bits, + B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -251,7 +266,7 @@ void qvm_split_k( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -259,7 +274,8 @@ void qvm_split_k( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int split_k = K > 8192 ? 
32 : 8; int split_D = (K + split_k - 1) / split_k; int B = out.size() / M / N; @@ -283,7 +299,6 @@ void qvm_split_k( auto w_shape = w.shape(); auto w_strides = w.strides(); auto s_strides = scales.strides(); - auto b_strides = biases.strides(); // Add split_k dim with reshapes x_shape.insert(x_shape.end() - 2, split_k); @@ -297,7 +312,6 @@ void qvm_split_k( w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1)); w_batch_ndims += 1; s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); - b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1)); int final_block_size = K - (split_k - 1) * split_D; @@ -315,7 +329,7 @@ void qvm_split_k( kname.reserve(64); concatenate( kname, - "qvm_split_k_", + mode + "_qvm_split_k_", type_string, "_gs_", group_size, @@ -324,30 +338,37 @@ void qvm_split_k( "_spk_", split_k); auto template_def = get_template_definition( - kname, "qvm_split_k", type_string, group_size, bits, split_k); + kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(intermediate, 4); - compute_encoder.set_bytes(split_D, 5); - compute_encoder.set_bytes(N, 6); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(intermediate, c++); + compute_encoder.set_bytes(split_D, c++); + compute_encoder.set_bytes(N, c++); - compute_encoder.set_bytes(x_batch_ndims, 7); - compute_encoder.set_vector_bytes(x_shape, 8); - compute_encoder.set_vector_bytes(x_strides, 9); - compute_encoder.set_bytes(w_batch_ndims, 10); - compute_encoder.set_vector_bytes(w_shape, 11); - compute_encoder.set_vector_bytes(w_strides, 12); - compute_encoder.set_vector_bytes(s_strides, 13); - compute_encoder.set_vector_bytes(b_strides, 14); - compute_encoder.set_bytes(final_block_size, 15); + compute_encoder.set_bytes(x_batch_ndims, c++); + compute_encoder.set_vector_bytes(x_shape, c++); + compute_encoder.set_vector_bytes(x_strides, c++); + compute_encoder.set_bytes(w_batch_ndims, c++); + compute_encoder.set_vector_bytes(w_shape, c++); + compute_encoder.set_vector_bytes(w_strides, c++); + compute_encoder.set_vector_bytes(s_strides, c++); + if (biases) { + auto b_strides = biases->strides(); + b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1)); + compute_encoder.set_vector_bytes(b_strides, c++); + } + compute_encoder.set_bytes(final_block_size, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -364,7 +385,7 @@ void qvm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -372,7 +393,8 @@ void qvm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 64; @@ -385,7 +407,7 @@ void qvm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - "qvm_", + mode + "_qvm_", type_string, "_gs_", group_size, @@ -393,20 +415,23 @@ void qvm( bits, B > 1 ? 
"_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, "qvm", type_string, group_size, bits, B > 1); + kname, mode + "_qvm", type_string, group_size, bits, B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -415,7 +440,7 @@ void qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, bool transpose, int group_size, @@ -424,7 +449,8 @@ void qmm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int wm = 2; @@ -441,7 +467,7 @@ void qmm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "qmm_t_" : "qmm_n_", + mode + (transpose ? "_qmm_t_" : "_qmm_n_"), type_string, "_gs_", group_size, @@ -452,25 +478,34 @@ void qmm( std::string template_def; if (transpose) { template_def = get_template_definition( - kname, "qmm_t", type_string, group_size, bits, aligned, batched); + kname, + mode + "_qmm_t", + type_string, + group_size, + bits, + aligned, + batched); } else { template_def = get_template_definition( - kname, "qmm_n", type_string, group_size, bits, batched); + kname, mode + "_qmm_n", type_string, group_size, bits, batched); } auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - compute_encoder.set_bytes(M, 7); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -479,7 +514,7 @@ void gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -490,7 +525,8 @@ void gather_qmm( int N, int 
K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int wm = 2; @@ -507,7 +543,7 @@ void gather_qmm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "gather_qmm_t_" : "gather_qmm_n_", + mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"), type_string, "_gs_", group_size, @@ -517,30 +553,31 @@ void gather_qmm( std::string template_def; if (transpose) { template_def = get_template_definition( - kname, "gather_qmm_t", type_string, group_size, bits, aligned); + kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned); } else { template_def = get_template_definition( - kname, "gather_qmm_n", type_string, group_size, bits); + kname, mode + "_gather_qmm_n", type_string, group_size, bits); } auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - compute_encoder.set_bytes(M, 9); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 10 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -549,7 +586,7 @@ void gather_qmv( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -559,7 +596,8 @@ void gather_qmv( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 8; @@ -573,7 +611,7 @@ void gather_qmv( bool fast = N % bn == 0 && K % 512 == 0; concatenate( kname, - fast ? "gather_qmv_fast_" : "gather_qmv_", + mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"), type_string, "_gs_", group_size, @@ -581,7 +619,7 @@ void gather_qmv( bits); auto template_def = get_template_definition( kname, - fast ? "gather_qmv_fast" : "gather_qmv", + mode + (fast ? 
"_gather_qmv_fast" : "_gather_qmv"), type_string, group_size, bits); @@ -590,19 +628,20 @@ void gather_qmv( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 9 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -611,7 +650,7 @@ void gather_qvm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -621,7 +660,8 @@ void gather_qvm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 64; @@ -633,27 +673,34 @@ void gather_qvm( kname.reserve(64); std::string type_string = get_type_string(x.dtype()); concatenate( - kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); + kname, + mode + "_gather_qvm_", + type_string, + "_gs_", + group_size, + "_b_", + bits); auto template_def = get_template_definition( - kname, "gather_qvm", type_string, group_size, bits); + kname, mode + "_gather_qvm", type_string, group_size, bits); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 9 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++); + 
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -662,7 +709,7 @@ void gather_qmm_rhs( const array& x_, const array& w_, const array& scales_, - const array& biases_, + const std::optional& biases_, const array& indices_, array& out, bool transpose, @@ -672,7 +719,8 @@ void gather_qmm_rhs( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string mode) { // Start by normalizing the indices array indices = ensure_row_contiguous(indices_, d, s); @@ -697,7 +745,6 @@ void gather_qmm_rhs( array x = broadcast_with_indices(x_); array w = ensure_row_contiguous(w_, d, s); array scales = ensure_row_contiguous(scales_, d, s); - array biases = ensure_row_contiguous(biases_, d, s); // TODO: Tune the block sizes int bm = 16, bn = 32, bk = 32; @@ -713,7 +760,7 @@ void gather_qmm_rhs( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_", + mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"), type_string, "_gs_", group_size, @@ -770,15 +817,19 @@ void gather_qmm_rhs( MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_input_array(indices, 4); - compute_encoder.set_output_array(out, 5); - compute_encoder.set_bytes(M, 6); - compute_encoder.set_bytes(N, 7); - compute_encoder.set_bytes(K, 8); + int c = 0; + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases_) { + array biases = ensure_row_contiguous(*biases_, d, s); + compute_encoder.set_input_array(biases, c++); + } + compute_encoder.set_input_array(indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(M, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(K, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -794,7 +845,10 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { array x = ensure_row_contiguous_matrix(inputs[0], d, s); array w = ensure_row_contiguous_matrix(inputs[1], d, s); array scales = ensure_row_contiguous_matrix(inputs[2], d, s); - array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + std::optional biases = std::nullopt; + if (inputs.size() == 4) { + biases = ensure_row_contiguous_matrix(inputs[3], d, s); + } // Extract the matmul shapes bool non_batched = w.ndim() == 2 && x.flags().row_contiguous; @@ -803,7 +857,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int N = out.shape(-1); int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; - + auto mode = quantization_mode_to_string(mode_); // It is a matrix matrix product. 
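  // Rough summary of the routing below (a reading of the branches, not a spec):
  //   M >= vector_limit                          -> qmm (matrix-matrix)
  //   transpose_ && K in {64, 128} && 2^n bits   -> qmv_quad (tiny inner dim)
  //   transpose_                                 -> qmv
  //   K < 1024                                   -> qvm
  //   otherwise                                  -> qvm_split_k (more parallelism)
  // Each helper now also receives `mode`, which becomes the "affine_" or
  // "mxfp4_" prefix of the kernel name it dispatches.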
if (M >= vector_limit) { qmm(x, @@ -818,30 +872,33 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode); return; } // It is a qmv with a small inner dimension so route to qmv_quad kernel if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) { - qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qmv_quad( + x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); return; } // Run of the mill qmv if (transpose_) { - qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); return; } // Run of the mill qvm if (K < 1024) { - qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); return; } // Qvm with large dimension so route to a split K kernel for more parallelism - qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qvm_split_k( + x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); return; } @@ -854,9 +911,12 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { array x = ensure_row_contiguous_matrix(inputs[0], d, s); array w = ensure_row_contiguous_matrix(inputs[1], d, s); array scales = ensure_row_contiguous_matrix(inputs[2], d, s); - array biases = ensure_row_contiguous_matrix(inputs[3], d, s); - const array& lhs_indices = inputs[4]; - const array& rhs_indices = inputs[5]; + std::optional biases = std::nullopt; + if (inputs.size() == 6) { + biases = ensure_row_contiguous_matrix(inputs[3], d, s); + } + const array& lhs_indices = inputs[inputs.size() - 2]; + const array& rhs_indices = inputs[inputs.size() - 1]; int K = x.shape(-1); int M = x.shape(-2); @@ -864,6 +924,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { int B = out.size() / M / N; int E = w.size() / w.shape(-1) / w.shape(-2); int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; + auto mode = quantization_mode_to_string(mode_); // We are walking x in order and w is also in order so we can batch up the // matmuls and reuse reading x and w. 
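Throughout these dispatch helpers the mode string only changes the kernel-name prefix; group size and bits still appear in the name, and the mxfp4 instantiations in fp4_quantized.metal exist only for gs_32 and b_4. A sketch of the name assembly for the qmv path, using a standalone helper that is not part of the patch:

#include <string>

// Illustration of the kernel-name scheme used by qmv(); e.g. affine float16
// with group_size=64, bits=4, unbatched, fast path yields
// "affine_qmv_fast_float16_t_gs_64_b_4_batch_0", while the mxfp4 path yields
// "mxfp4_qmv_fast_float16_t_gs_32_b_4_batch_0".
std::string qmv_kernel_name(
    const std::string& mode, // "affine" or "mxfp4"
    const std::string& type_string,
    int group_size,
    int bits,
    bool fast,
    bool batched) {
  std::string kname = mode + (fast ? "_qmv_fast_" : "_qmv_");
  kname += type_string;
  kname += "_gs_" + std::to_string(group_size);
  kname += "_b_" + std::to_string(bits);
  kname += batched ? "_batch_1" : "_batch_0";
  return kname;
}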
@@ -884,7 +945,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode); return; } @@ -905,7 +967,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode); return; } @@ -924,7 +987,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode); return; } @@ -942,10 +1006,11 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode); } -void fast::AffineQuantize::eval_gpu( +void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& w_pre = inputs[0]; diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 09e6c4ef3..dba82c6dc 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -129,7 +129,7 @@ NO_CPU(Inverse) NO_CPU(View) namespace fast { -NO_CPU_MULTI(AffineQuantize) +NO_CPU_MULTI(Quantize) } // namespace fast namespace distributed { diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index dfe5b57f1..22a0c8acc 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -154,7 +154,7 @@ NO_GPU_USE_FALLBACK(RMSNorm) NO_GPU_MULTI(RMSNormVJP) NO_GPU_USE_FALLBACK(RoPE) NO_GPU(ScaledDotProductAttention) -NO_GPU_MULTI(AffineQuantize) +NO_GPU_MULTI(Quantize) NO_GPU_MULTI(CustomKernel) } // namespace fast diff --git a/mlx/export.cpp b/mlx/export.cpp index 7099f4864..19944dfc4 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -335,7 +335,7 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Cholesky), SERIALIZE_PRIMITIVE(Eig), SERIALIZE_PRIMITIVE(Eigh), - SERIALIZE_PRIMITIVE(AffineQuantize), + SERIALIZE_PRIMITIVE(Quantize), SERIALIZE_PRIMITIVE(RMSNorm), SERIALIZE_PRIMITIVE(RMSNormVJP), SERIALIZE_PRIMITIVE(LayerNorm), diff --git a/mlx/fast.cpp b/mlx/fast.cpp index b8d622253..befc9f80c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -762,255 +762,14 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_; } -array pack_and_quantize( - array& packed_w, - const array& scales, - const array& biases, - int bits, - const Stream& s) { - int el_per_int = 32 / bits; - array zero(0, packed_w.dtype()); - array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 - packed_w = astype( - clip( - round(divide(subtract(packed_w, biases, s), scales, s), s), - zero, - n_bins, - s), - uint32, - s); - if (is_power_of_2(bits)) { - array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); - packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); - packed_w = sum( - multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); - } else { - // This is slow but we have fast GPU/CPU versions of this function so we - // shouldn't be here often. 
- packed_w = expand_dims(packed_w, /* axis= */ -1, s); - packed_w = bitwise_and( - right_shift(packed_w, arange(bits, uint32, s), s), - array({1}, uint32), - s); - auto new_shape = packed_w.shape(); - new_shape[new_shape.size() - 2] = -1; - new_shape.back() = 32; - packed_w = reshape(packed_w, new_shape, s); - array shifts = arange(32, uint32, s); - packed_w = - sum(left_shift(packed_w, shifts, s), - /* axis= */ -1, - /* keepdims= */ false, - s); - } - return packed_w; -} - -std::tuple -affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { - auto s = to_stream(s_); - - if (group_size != 32 && group_size != 64 && group_size != 128) { - std::ostringstream msg; - msg << "[quantize] The requested group size " << group_size - << " is not supported. The supported group sizes are 32, 64, and 128."; - throw std::invalid_argument(msg.str()); - } - - if (bits < 2 || bits > 8 || bits == 7) { - std::ostringstream msg; - msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; - throw std::invalid_argument(msg.str()); - } - - if (w.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - if ((w.shape(-1) % group_size) != 0) { - std::ostringstream msg; - msg << "[quantize] The last dimension of the matrix needs to be divisible by " - << "the quantization group size " << group_size - << ". However the provided " << " matrix has shape " << w.shape(); - throw std::invalid_argument(msg.str()); - } - - auto fallback = [group_size, bits, s]( - const std::vector& inputs) -> std::vector { - auto& w = inputs[0]; - auto wshape = w.shape(); - wshape.back() = -1; - - array zero(0, float32); - array n_bins((1 << bits) - 1, float32); // 2**bits - 1 - array eps(1e-7, float32); - - array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); - - array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - w_max = astype(w_max, float32, s); - w_min = astype(w_min, float32, s); - - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = - maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales, s), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge, s); - - packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); - - scales = astype(scales, w.dtype(), s); - biases = astype(biases, w.dtype(), s); - return { - reshape(packed_w, wshape, s), - reshape(scales, wshape, s), - reshape(biases, wshape, s), - }; - }; - - auto wq_shape = w.shape(); - wq_shape.back() = w.shape(-1) * bits / 32; - auto sshape = w.shape(); - sshape.back() = w.shape(-1) / group_size; - auto outputs = array::make_arrays( - {std::move(wq_shape), sshape, sshape}, - {uint32, w.dtype(), w.dtype()}, - std::make_shared(s, fallback, group_size, bits, false), - {w}); - return {outputs[0], outputs[1], outputs[2]}; -} - -array affine_dequantize( - const array& w, - const array& scales, - const array& biases, - int group_size, - int bits, - StreamOrDevice s_) { - if (bits <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for bits: " << bits; - throw 
std::invalid_argument(msg.str()); - } - if (group_size <= 0) { - std::ostringstream msg; - msg << "[dequantize] Invalid value for group_size: " << group_size; - throw std::invalid_argument(msg.str()); - } - if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { - std::ostringstream msg; - msg << "[quantize] The matrix to be quantized must have at least 2 dimension " - << "but it has only " << w.ndim() << "."; - throw std::invalid_argument(msg.str()); - } - - auto wshape = w.shape(); - auto sshape = scales.shape(); - auto bshape = biases.shape(); - wshape.back() = -1; - sshape.back() = -1; - bshape.back() = -1; - - if (wshape != sshape || wshape != bshape) { - throw std::invalid_argument( - "[dequantize] Shape of scales and biases does not match the matrix"); - } - - if (w.dtype() != uint32) { - throw std::invalid_argument( - "[dequantize] The matrix should be given as a uint32"); - } - - // Packing into uint32 - int out_size = w.shape(-1) * 32 / bits; - - if (out_size != scales.shape(-1) * group_size) { - std::ostringstream msg; - msg << "[dequantize] Shape of scales and biases does not match the matrix " - << "given the quantization parameters. Provided matrix of shape " - << w.shape() << " and scales/biases of shape " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits << "."; - throw std::invalid_argument(msg.str()); - } - - auto s = to_stream(s_); - - auto fallback = - [wshape = std::move(wshape), - sshape = std::move(sshape), - group_size, - bits, - s](const std::vector& inputs) mutable -> std::vector { - auto w = inputs[0]; - auto& scales = inputs[1]; - auto& biases = inputs[2]; - if (is_power_of_2(bits)) { - std::vector parts; - for (int start = 0; start < 32; start += bits) { - int shift_left = 32 - (start + bits); - int shift_right = shift_left + start; - - parts.push_back(expand_dims( - right_shift( - left_shift(w, array(32 - (start + bits), uint32), s), - array(32 - bits, uint32), - s), - -1, - s)); - } - w = concatenate(parts, -1, s); - } else { - w = expand_dims(w, /* axis= */ -1, s); - w = bitwise_and( - right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s); - auto new_shape = w.shape(); - new_shape[new_shape.size() - 2] = -1; - new_shape.back() = bits; - w = reshape(w, new_shape, s); - array shifts = arange(bits, uint32, s); - w = sum( - left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s); - } - - // Dequantize - wshape.push_back(group_size); - w = reshape(w, wshape, s); - w = multiply(w, expand_dims(scales, -1, s), s); - w = add(w, expand_dims(biases, -1, s), s); - w = reshape(w, sshape, s); - - return {w}; - }; - - if (s.device == Device::gpu) { - auto out_shape = w.shape(); - out_shape.back() = out_size; - return array( - std::move(out_shape), - scales.dtype(), - std::make_shared(s, fallback, group_size, bits, true), - {w, scales, biases}); - } - return fallback({w, scales, biases})[0]; -} - -bool AffineQuantize::is_equivalent(const Primitive& other) const { - const AffineQuantize& p_other = static_cast(other); +bool Quantize::is_equivalent(const Primitive& other) const { + const Quantize& p_other = static_cast(other); return ( p_other.group_size_ == group_size_ && p_other.bits_ == bits_ && - p_other.dequantize_ == dequantize_); + p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_); } -std::vector AffineQuantize::output_shapes( - const std::vector& inputs) { +std::vector Quantize::output_shapes(const std::vector& inputs) { auto& w = inputs[0]; if (dequantize_) { auto out_size = w.shape(-1) * 32 / 
bits_; @@ -1022,8 +781,12 @@ std::vector AffineQuantize::output_shapes( wq_shape.back() = w.shape(-1) * bits_ / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size_; - auto bshape = sshape; - return {std::move(wq_shape), std::move(sshape), std::move(bshape)}; + if (inputs.size() == 2) { + return {std::move(wq_shape), std::move(sshape)}; + } else { + auto bshape = sshape; + return {std::move(wq_shape), std::move(sshape), std::move(bshape)}; + } } } diff --git a/mlx/fast.h b/mlx/fast.h index d154e4753..10f9ced96 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -52,20 +52,6 @@ array scaled_dot_product_attention( const std::vector& mask_arrs = {}, StreamOrDevice s = {}); -std::tuple affine_quantize( - const array& w, - int group_size = 64, - int bits = 4, - StreamOrDevice s = {}); - -array affine_dequantize( - const array& w, - const array& scales, - const array& biases, - int group_size = 64, - int bits = 4, - StreamOrDevice s = {}); - using TemplateArg = std::variant; using ScalarArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index e0e83f726..fd6ba8fed 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -245,17 +245,19 @@ class ScaledDotProductAttention : public Custom { bool do_causal_; }; -class AffineQuantize : public Custom { +class Quantize : public Custom { public: - explicit AffineQuantize( + explicit Quantize( Stream stream, std::function(std::vector)> fallback, int group_size, int bits, + QuantizationMode mode, bool dequantize) : Custom(stream, fallback), group_size_(group_size), bits_(bits), + mode_(mode), dequantize_(dequantize) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) @@ -264,17 +266,18 @@ class AffineQuantize : public Custom { void eval_gpu(const std::vector& inputs, std::vector& outputs) override; - DEFINE_NAME(AffineQuantize); + DEFINE_NAME(Quantize); bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(nullptr, group_size_, bits_, dequantize_); + return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_); } private: int group_size_; int bits_; + QuantizationMode mode_; bool dequantize_; }; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c8583c72f..a2271c4fd 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -10,7 +10,7 @@ #include #include -#include "mlx/fast.h" +#include "mlx/fast_primitives.h" #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" @@ -76,7 +76,7 @@ std::pair extract_quantized_matmul_dims( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, bool transpose, int group_size, int bits) { @@ -87,11 +87,11 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } - if (scales.shape() != biases.shape()) { + if (biases && scales.shape() != biases->shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " << "Received scales with shape " << scales.shape() - << " and biases with " << biases.shape(); + << " and biases with " << biases->shape(); throw std::invalid_argument(msg.str()); } @@ -99,9 +99,9 @@ std::pair extract_quantized_matmul_dims( w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { std::ostringstream msg; msg << "[" << tag - << "] Weight, scales and biases should have the same batch shape. " + << "] Weight and scales should have the same batch shape. 
" << "Received weight with shape " << w.shape() << ", scales with " - << scales.shape() << " and biases with " << biases.shape(); + << scales.shape() << "."; throw std::invalid_argument(msg.str()); } @@ -4021,30 +4021,76 @@ array conv_general( {in, wt}); } +void validate_mode(std::string_view tag, const std::string& mode) { + if (mode != "affine" && mode != "mxfp4") { + std::ostringstream msg; + msg << "[" << tag << "] Invalid quantization mode '" << mode << "'."; + throw std::invalid_argument(msg.str()); + } +} + +Dtype validate_mode_with_type( + std::string_view tag, + const array& scales, + const std::optional& biases, + const std::string& mode) { + validate_mode(tag, mode); + if (mode == "affine") { + if (!biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be provided for affine quantization."; + throw std::invalid_argument(msg.str()); + } + auto dtype = result_type(scales, *biases); + if (!issubdtype(dtype, floating)) { + std::ostringstream msg; + msg << "[" << tag << "] Only real floating types are supported but " + << "scales.dtype() == " << scales.dtype() + << " and biases.dtype() == " << biases->dtype() << "."; + throw std::invalid_argument(msg.str()); + } + return dtype; + } + if (biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be null for quantization mode '" << mode + << "'."; + throw std::invalid_argument(msg.str()); + } + return bfloat16; +} + array quantized_matmul( array x, array w, array scales, - array biases, + std::optional biases /* = std::nullopt */, bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); - auto dtype = result_type(x, scales, biases); + auto dtype = + validate_mode_with_type("quantized_matmul", scales, biases, mode); + dtype = promote_types(x.dtype(), dtype); + if (!issubdtype(dtype, floating)) { std::ostringstream msg; msg << "[quantized_matmul] Only real floating types are supported but " - << "the passed types where x.dtype() == " << x.dtype() - << ", scales.dtype() == " << scales.dtype() - << " and biases.dtype() == " << biases.dtype(); + << "x.dtype() == " << x.dtype() << "."; throw std::invalid_argument(msg.str()); } - std::vector inputs = { - astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)}; + std::vector inputs; + if (mode == "affine") { + inputs = { + astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; + } else { + inputs = {x, w, scales}; + } if (x.ndim() > 2 && w.ndim() > 2) { inputs = broadcast_arrays(inputs, {-2, -1}, s); @@ -4056,48 +4102,447 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, transpose), + to_stream(s), + group_size, + bits, + string_to_quantization_mode(mode), + transpose), std::move(inputs)); } -std::tuple quantize( +array pack_and_quantize( + array& packed_w, + const array& scales, + const array& biases, + int bits, + const Stream& s) { + int el_per_int = 32 / bits; + array zero(0, packed_w.dtype()); + array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 + packed_w = astype( + clip( + round(divide(subtract(packed_w, biases, s), scales, s), s), + zero, + n_bins, + s), + uint32, + s); + if (is_power_of_2(bits)) { + array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), 
s); + packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); + packed_w = sum( + multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + } else { + // This is slow but we have fast GPU/CPU versions of this function so we + // shouldn't be here often. + packed_w = expand_dims(packed_w, /* axis= */ -1, s); + packed_w = bitwise_and( + right_shift(packed_w, arange(bits, uint32, s), s), + array({1}, uint32), + s); + auto new_shape = packed_w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = 32; + packed_w = reshape(packed_w, new_shape, s); + array shifts = arange(32, uint32, s); + packed_w = + sum(left_shift(packed_w, shifts, s), + /* axis= */ -1, + /* keepdims= */ false, + s); + } + return packed_w; +} + +std::vector +affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { + auto s = to_stream(s_); + if (group_size != 32 && group_size != 64 && group_size != 128) { + std::ostringstream msg; + msg << "[quantize] The requested group size " << group_size + << " is not supported. The supported group sizes are 32, 64, and 128."; + throw std::invalid_argument(msg.str()); + } + + if (bits < 2 || bits > 8 || bits == 7) { + std::ostringstream msg; + msg << "[quantize] The requested number of bits " << bits + << " is not supported. The supported bits are 2, 3, 4, 5, 6 and 8."; + throw std::invalid_argument(msg.str()); + } + + auto fallback = [group_size, bits, s]( + const std::vector& inputs) -> std::vector { + auto& w = inputs[0]; + auto wshape = w.shape(); + wshape.back() = -1; + + array zero(0, float32); + array n_bins((1 << bits) - 1, float32); // 2**bits - 1 + array eps(1e-7, float32); + + array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); + + array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + w_max = astype(w_max, float32, s); + w_min = astype(w_min, float32, s); + + array mask = greater(abs(w_min, s), abs(w_max, s), s); + array scales = + maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales, s), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + array biases = where(equal(q0, zero, s), zero, edge, s); + + packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); + + scales = astype(scales, w.dtype(), s); + biases = astype(biases, w.dtype(), s); + return { + reshape(packed_w, wshape, s), + reshape(scales, wshape, s), + reshape(biases, wshape, s), + }; + }; + + auto wq_shape = w.shape(); + wq_shape.back() = w.shape(-1) * bits / 32; + auto sshape = w.shape(); + sshape.back() = w.shape(-1) / group_size; + return array::make_arrays( + {std::move(wq_shape), sshape, sshape}, + {uint32, w.dtype(), w.dtype()}, + std::make_shared( + s, fallback, group_size, bits, QuantizationMode::Affine, false), + {w}); +} + +std::vector quantize( const array& w, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { - return fast::affine_quantize(w, group_size, bits, s); + validate_mode("quantize", mode); + if (!issubdtype(w.dtype(), floating)) { + std::ostringstream msg; + msg << "[quantize] Only real floating types can be quantized " + << "but w has type " << w.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (w.ndim() < 2) { + std::ostringstream msg; + msg << 
"[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + if ((w.shape(-1) % group_size) != 0) { + std::ostringstream msg; + msg << "[quantize] The last dimension of the matrix needs to be divisible by " + << "the quantization group size " << group_size + << ". However the provided " << " matrix has shape " << w.shape(); + throw std::invalid_argument(msg.str()); + } + + if (mode == "affine") { + return affine_quantize(w, group_size, bits, s); + } else { + if (group_size != 32) { + std::ostringstream msg; + msg << "[quantize] mxfp4 quantization requires group size 32 " + << "but got " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != 4) { + std::ostringstream msg; + msg << "[quantize] mxfp4 quantization requires bits to be 4 " + << "but got " << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto lut = array({ + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }); + lut = astype(lut, w.dtype(), s); + + auto new_shape = w.shape(); + new_shape.back() = -1; + auto wq = reshape(w, {-1, group_size}, s); + auto scales = + divide(max(abs(wq, s), -1, true, s), array(6.0f, w.dtype()), s); + scales = astype(log2(scales, s), int32, s); + wq = divide(wq, power(array(2.0f, w.dtype()), scales, s), s); + scales = astype(add(scales, array(127, int32), s), uint8, s); + wq = argmin(abs(subtract(expand_dims(wq, -1, s), lut, s), s), -1, false, s); + auto shifts = power(array(2, uint32), arange(0, 32, 4, uint32, s), s); + wq = reshape(wq, {-1, group_size / 8, 8}, s); + wq = sum(multiply(wq, shifts, s), -1, false, s); + wq = reshape(wq, new_shape, s); + scales = reshape(scales, new_shape, s); + return {std::move(wq), std::move(scales)}; + } +} + +array affine_dequantize( + const array& w, + const array& scales, + const array& biases, + int group_size, + int bits, + StreamOrDevice s_) { + if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + auto bshape = biases.shape(); + wshape.back() = -1; + sshape.back() = -1; + bshape.back() = -1; + + if (wshape != sshape || wshape != bshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales and biases does not match the matrix"); + } + + // Packing into uint32 + int out_size = w.shape(-1) * 32 / bits; + + if (out_size != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales and biases does not match the matrix " + << "given the quantization parameters. 
Provided matrix of shape " + << w.shape() << " and scales/biases of shape " << scales.shape() + << " with group_size=" << group_size << " and bits=" << bits << "."; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + + auto fallback = + [wshape = std::move(wshape), + sshape = std::move(sshape), + group_size, + bits, + s](const std::vector& inputs) mutable -> std::vector { + auto w = inputs[0]; + auto& scales = inputs[1]; + auto& biases = inputs[2]; + if (is_power_of_2(bits)) { + std::vector parts; + for (int start = 0; start < 32; start += bits) { + int shift_left = 32 - (start + bits); + int shift_right = shift_left + start; + + parts.push_back(expand_dims( + right_shift( + left_shift(w, array(32 - (start + bits), uint32), s), + array(32 - bits, uint32), + s), + -1, + s)); + } + w = concatenate(parts, -1, s); + } else { + w = expand_dims(w, /* axis= */ -1, s); + w = bitwise_and( + right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s); + auto new_shape = w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = bits; + w = reshape(w, new_shape, s); + array shifts = arange(bits, uint32, s); + w = sum( + left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s); + } + + // Dequantize + wshape.push_back(group_size); + w = reshape(w, wshape, s); + w = multiply(w, expand_dims(scales, -1, s), s); + w = add(w, expand_dims(biases, -1, s), s); + w = reshape(w, sshape, s); + + return {w}; + }; + + if (s.device == Device::gpu) { + auto out_shape = w.shape(); + out_shape.back() = out_size; + return array( + std::move(out_shape), + scales.dtype(), + std::make_shared( + s, fallback, group_size, bits, QuantizationMode::Affine, true), + {w, scales, biases}); + } + return fallback({w, scales, biases})[0]; } array dequantize( const array& w, const array& scales, - const array& biases, + const std::optional& biases /* = std::nullopt */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { - return fast::affine_dequantize(w, scales, biases, group_size, bits, s); + validate_mode_with_type("dequantize", scales, biases, mode); + if (bits <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for bits: " << bits; + throw std::invalid_argument(msg.str()); + } + if (group_size <= 0) { + std::ostringstream msg; + msg << "[dequantize] Invalid value for group_size: " << group_size; + throw std::invalid_argument(msg.str()); + } + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + if (mode == "affine") { + return affine_dequantize(w, scales, *biases, group_size, bits, s); + } else { + if (group_size != 32) { + std::ostringstream msg; + msg << "[dequantize] mxfp4 quantization requires group size 32 " + << "but got " << group_size << "."; + throw std::invalid_argument(msg.str()); + } + if (bits != 4) { + std::ostringstream msg; + msg << "[dequantize] mxfp4 quantization requires bits to be 4 " + << "but got " << bits << "."; + throw std::invalid_argument(msg.str()); + } + + if (w.ndim() < 2 || scales.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + auto wshape = w.shape(); + auto sshape = scales.shape(); + wshape.back() = -1; + sshape.back() = -1; + + if (wshape != sshape) { + throw std::invalid_argument( + "[dequantize] Shape of scales does not 
match the matrix"); + } + + if (w.dtype() != uint32) { + throw std::invalid_argument( + "[dequantize] The matrix should be given as a uint32"); + } + + // Packing into uint32 + int out_size = w.shape(-1) * 32 / bits; + + if (out_size != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[dequantize] Shape of scales does not match the matrix " + << "given the quantization parameters. Provided matrix of shape " + << w.shape() << " and scales of shape " << scales.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto dtype = bfloat16; + auto lut = array( + { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f, + }, + dtype); + + auto what = view(reshape(w, {-1, group_size / 8}, s), int8, s); + + auto idx_lo = bitwise_and(what, array(0x0F, int8), s); + auto idx_hi = right_shift(what, array(4, int8), s); + auto lo = gather(lut, idx_lo, 0, {1}, s); + auto hi = gather(lut, idx_hi, 0, {1}, s); + what = flatten(concatenate({lo, hi}, -1, s), -2, -1, s); + auto exponent = subtract(astype(scales, dtype, s), array(127, dtype), s); + exponent = reshape(exponent, {-1, 1}, s); + return reshape( + multiply(power(array(2.0f, dtype), exponent, s), what, s), wshape, s); + } } array gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases /* = std::nullopt */, std::optional lhs_indices_ /* = std::nullopt */, std::optional rhs_indices_ /* = std::nullopt */, bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size, bits, s); + x, w, scales, biases, transpose, group_size, bits, mode, s); } auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "gather_qmm", x, w, scales, biases, transpose, group_size, bits); + auto out_type = validate_mode_with_type("gather_qmm", scales, biases, mode); + out_type = promote_types(x.dtype(), out_type); + + if (!issubdtype(out_type, floating)) { + std::ostringstream msg; + msg << "[gather_qmm] Only real floating types are supported but " + << "x.dtype() == " << x.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + // Extract indices and broadcast them array lhs_indices = indices_or_default(lhs_indices_, x, s); array rhs_indices = indices_or_default(rhs_indices_, w, s); @@ -4113,6 +4558,12 @@ array gather_qmm( throw std::invalid_argument( "[gather_qmm] Got rhs_indices with invalid dtype. 
Indices must be integral."); } + if (x.ndim() < 2) { + std::ostringstream msg; + msg << "[gather_qmm] Non-quantized input must have at least two" + << " dimensions but got input with shape " << x.shape() << "."; + throw std::invalid_argument(msg.str()); + } lhs_indices = astype(lhs_indices, uint32, s); rhs_indices = astype(rhs_indices, uint32, s); @@ -4121,10 +4572,23 @@ array gather_qmm( auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); out_shape.push_back(w_outer_dims); - - // and output type - auto out_type = result_type(x, scales, biases); - + std::vector inputs; + if (mode == "affine") { + inputs = { + astype(x, out_type, s), + std::move(w), + astype(scales, out_type, s), + astype(*biases, out_type, s), + std::move(lhs_indices), + std::move(rhs_indices)}; + } else { + inputs = { + astype(x, out_type, s), + std::move(w), + std::move(scales), + std::move(lhs_indices), + std::move(rhs_indices)}; + } return array( std::move(out_shape), out_type, @@ -4132,15 +4596,11 @@ array gather_qmm( to_stream(s), group_size, bits, + string_to_quantization_mode(mode), transpose, sorted_indices && !rhs_indices_, sorted_indices && !lhs_indices_), - {astype(x, out_type, s), - std::move(w), - astype(scales, out_type, s), - astype(biases, out_type, s), - std::move(lhs_indices), - std::move(rhs_indices)}); + std::move(inputs)); } array tensordot( diff --git a/mlx/ops.h b/mlx/ops.h index 596d6d287..826f6d47b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1322,26 +1322,29 @@ array quantized_matmul( array x, array w, array scales, - array biases, + std::optional biases = std::nullopt, bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ -std::tuple quantize( +std::vector quantize( const array& w, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ array dequantize( const array& w, const array& scales, - const array& biases, + const std::optional& biases = std::nullopt, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Compute matrix products with matrix-level gather. 
*/ @@ -1349,12 +1352,13 @@ array gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases = std::nullopt, std::optional lhs_indices = std::nullopt, std::optional rhs_indices = std::nullopt, bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", bool sorted_indices = false, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 980a1f7c3..977e5c62a 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3208,6 +3208,22 @@ std::pair, std::vector> Power::vmap( return {{power(a, b, stream())}, {to_ax}}; } +std::string quantization_mode_to_string(QuantizationMode mode) { + if (mode == QuantizationMode::Affine) { + return "affine"; + } else { + return "mxfp4"; + } +} + +QuantizationMode string_to_quantization_mode(const std::string& mode) { + if (mode == "affine") { + return QuantizationMode::Affine; + } else { + return QuantizationMode::Mxfp4; + } +} + std::pair, std::vector> QuantizedMatmul::vmap( const std::vector& inputs, const std::vector& axes) { @@ -3234,6 +3250,7 @@ std::vector QuantizedMatmul::vjp( !transpose_, group_size_, bits_, + quantization_mode_to_string(mode_), stream())); } @@ -3242,6 +3259,10 @@ std::vector QuantizedMatmul::vjp( throw std::runtime_error( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { + if (mode_ == QuantizationMode::Mxfp4) { + throw std::runtime_error( + "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization."); + } if (!dsb) { int ndim = primals[1].ndim(); auto fc = flatten(cotangents[0], 0, -ndim, stream()); @@ -3262,6 +3283,7 @@ std::vector QuantizedMatmul::vjp( zeros_like(primals[3], stream()), group_size_, bits_, + quantization_mode_to_string(mode_), stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); @@ -3287,13 +3309,14 @@ std::vector QuantizedMatmul::jvp( transpose_, group_size_, bits_, + quantization_mode_to_string(mode_), stream())}; } bool QuantizedMatmul::is_equivalent(const Primitive& other) const { const QuantizedMatmul& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; } std::vector QuantizedMatmul::output_shapes( @@ -3348,6 +3371,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + quantization_mode_to_string(mode_), sorted, stream()); if (sorted && no_broadcast) { @@ -3368,14 +3392,19 @@ std::vector GatherQMM::vjp( // gradient wrt to the indices is undefined else if (arg > 3) { throw std::runtime_error( - "GatherQMM::vjp cannot compute the gradient wrt the indices."); + "[GatherQMM::vjp] cannot compute the gradient wrt the indices."); } // gradient wrt to w_q, scales or biases else if (arg == 1) { throw std::runtime_error( - "GatherQMM::vjp no gradient wrt the quantized weights."); + "[GatherQMM::vjp] no gradient wrt the quantized weights."); } else { + if (mode_ == QuantizationMode::Mxfp4) { + throw std::runtime_error( + "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization."); + } + if (!dsb) { auto shape = w.shape(); shape.pop_back(); @@ -3406,6 +3435,7 @@ std::vector GatherQMM::vjp( zeros_like(biases, stream()), group_size_, bits_, + quantization_mode_to_string(mode_), stream()), -1, {-1, group_size_}, @@ -3430,7 +3460,7 @@ std::vector GatherQMM::jvp( bool GatherQMM::is_equivalent(const Primitive& other) 
const { const GatherQMM& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; } std::pair, std::vector> RandomBits::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index 277e42a0b..986675f3a 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -151,6 +151,11 @@ class UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; +enum class QuantizationMode { Affine, Mxfp4 }; + +std::string quantization_mode_to_string(QuantizationMode mode); +QuantizationMode string_to_quantization_mode(const std::string& mode); + class Abs : public UnaryPrimitive { public: explicit Abs(Stream stream) : UnaryPrimitive(stream) {} @@ -1597,10 +1602,12 @@ class QuantizedMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, + QuantizationMode mode, bool transpose) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), + mode_(mode), transpose_(transpose) {} void eval_cpu(const std::vector& inputs, array& out) override; @@ -1612,12 +1619,13 @@ class QuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple(group_size_, bits_, mode_, transpose_); } private: int group_size_; int bits_; + QuantizationMode mode_; bool transpose_; }; @@ -1627,12 +1635,14 @@ class GatherQMM : public UnaryPrimitive { Stream stream, int group_size, int bits, + QuantizationMode mode, bool transpose, bool left_sorted = false, bool right_sorted = false) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), + mode_(mode), transpose_(transpose), left_sorted_(left_sorted), right_sorted_(right_sorted) {} @@ -1646,12 +1656,13 @@ class GatherQMM : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - group_size_, bits_, transpose_, left_sorted_, right_sorted_); + group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); } private: int group_size_; int bits_; + QuantizationMode mode_; bool transpose_; bool left_sorted_; bool right_sorted_; diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index 1e15a59cc..1edf7e3a5 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -39,6 +39,6 @@ class Embedding(Module): """ return x @ self.weight.T - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"): """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" - return QuantizedEmbedding.from_embedding(self, group_size, bits) + return QuantizedEmbedding.from_embedding(self, group_size, bits, mode) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 63caa911c..84a4d8327 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -70,9 +70,9 @@ class Linear(Module): x = x @ self["weight"].T return x - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = "affine"): """Return a :obj:`QuantizedLinear` layer that approximates this layer.""" - return QuantizedLinear.from_linear(self, group_size, bits) + return 
QuantizedLinear.from_linear(self, group_size, bits, mode) class Bilinear(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 2d6dc0882..669162e68 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -12,6 +12,8 @@ def quantize( model: Module, group_size: int = 64, bits: int = 4, + *, + mode: str = "affine", class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, ): """Quantize the sub-modules of a module according to a predicate. @@ -26,6 +28,8 @@ def quantize( :func:`mlx.core.quantize`). Default: ``64``. bits (int): The number of bits per parameter (see :func:`mlx.core.quantize`). Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. class_predicate (Optional[Callable]): A callable which receives the :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a dict of params for `to_quantized` if it should be quantized and @@ -39,7 +43,7 @@ def quantize( if bool_or_params := class_predicate(path, m): if hasattr(m, "to_quantized"): if isinstance(bool_or_params, bool): - return m.to_quantized(group_size=group_size, bits=bits) + return m.to_quantized(group_size=group_size, bits=bits, mode=mode) elif isinstance(bool_or_params, dict): return m.to_quantized(**bool_or_params) else: @@ -72,6 +76,8 @@ class QuantizedEmbedding(Module): weight. See :func:`~mlx.core.quantize`. Default: ``64``. bits (int, optional): The bit width to use for the quantized weight. See :func:`~mlx.core.quantize`. Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. """ def __init__( @@ -80,17 +86,23 @@ class QuantizedEmbedding(Module): dims: int, group_size: int = 64, bits: int = 4, + mode: str = "affine", ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / dims) weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) - self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + if mode == "affine": + self.scales, self.biases = scales_biases + else: + (self.scales,) = scales_biases self.num_embeddings = num_embeddings self.dims = dims @@ -98,12 +110,14 @@ class QuantizedEmbedding(Module): self.freeze() def __call__(self, x): + biases = self.get("biases") return mx.dequantize( self["weight"][x], scales=self["scales"][x], - biases=self["biases"][x], + biases=biases[x] if biases is not None else None, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) def as_linear(self, x): @@ -117,28 +131,40 @@ class QuantizedEmbedding(Module): x, self["weight"], scales=self["scales"], - biases=self["biases"], + biases=self.get("biases"), transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) def _extra_repr(self): return ( f"{self.num_embeddings}, {self.dims}, " - f"group_size={self.group_size}, bits={self.bits}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" ) @classmethod def from_embedding( - cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 + cls, + embedding_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: str = "affine", ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape - ql = cls(embedding_dims, 
dims, group_size, bits) - ql.weight, ql.scales, ql.biases = mx.quantize( - embedding_layer.weight, group_size, bits + ql = cls(embedding_dims, dims, group_size, bits, mode=mode) + ql.weight, *scales_biases = mx.quantize( + embedding_layer.weight, + group_size, + bits, + mode=mode, ) + if mode == "affine": + ql.scales, ql.biases = scales_biases + else: + (ql.scales,) = scales_biases return ql @@ -161,6 +187,8 @@ class QuantizedLinear(Module): weight. See :func:`~mlx.core.quantize`. Default: ``64``. bits (int, optional): The bit width to use for the quantized weight. See :func:`~mlx.core.quantize`. Default: ``4``. + mode (str): The quantization method to use (see + :func:`mlx.core.quantize`). Default: ``"affine"``. """ def __init__( @@ -170,12 +198,14 @@ class QuantizedLinear(Module): bias: bool = True, group_size: int = 64, bits: int = 4, + mode: str = "affine", ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -184,7 +214,11 @@ class QuantizedLinear(Module): high=scale, shape=(output_dims, input_dims), ) - self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + if mode == "affine": + self.scales, self.biases = scales_biases + else: + (self.scales,) = scales_biases # And bias if needed if bias: @@ -198,7 +232,7 @@ class QuantizedLinear(Module): in_dims *= 32 // self.bits return ( f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " - f"group_size={self.group_size}, bits={self.bits}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" ) def __call__(self, x): @@ -206,23 +240,38 @@ class QuantizedLinear(Module): x, self["weight"], scales=self["scales"], - biases=self["biases"], + biases=self.get("biases"), transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) if "bias" in self: x = x + self["bias"] return x @classmethod - def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): + def from_linear( + cls, + linear_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: str = "affine", + ): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape - ql = cls(input_dims, output_dims, False, group_size, bits) - ql.weight, ql.scales, ql.biases = mx.quantize( - linear_layer.weight, group_size, bits + ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode) + ql.weight, *scales_biases = mx.quantize( + linear_layer.weight, + group_size, + bits, + mode=mode, ) + if mode == "affine": + ql.scales, ql.biases = scales_biases + else: + (ql.scales,) = scales_biases + if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f2a27e282..7b585af6e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4153,14 +4153,15 @@ void init_ops(nb::module_& m) { nb::arg(), nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: 
bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4171,7 +4172,8 @@ void init_ops(nb::module_& m) { x (array): Input array w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. transpose (bool, optional): Defines whether to multiply with the transposed ``w`` or not, namely whether we are performing ``x @ w.T`` or ``x @ w``. Default: ``True``. @@ -4179,6 +4181,7 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: array: The result of the multiplication of ``x`` with ``w``. @@ -4189,10 +4192,11 @@ void init_ops(nb::module_& m) { nb::arg(), "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( Quantize the matrix ``w`` using ``bits`` bits per element. @@ -4203,30 +4207,11 @@ void init_ops(nb::module_& m) { .. warning:: - ``quantize`` currently only supports 2D inputs with dimensions which are multiples of 32 + ``quantize`` currently only supports 2D inputs with the second + dimension divisible by ``group_size`` - Formally, for a group of :math:`g` consecutive elements :math:`w_1` to - :math:`w_g` in a row of ``w`` we compute the quantized representation - of each element :math:`\hat{w_i}` as follows - - .. math:: - - \begin{aligned} - \alpha &= \max_i w_i \\ - \beta &= \min_i w_i \\ - s &= \frac{\alpha - \beta}{2^b - 1} \\ - \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). - \end{aligned} - - After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits - and is packed in an unsigned 32-bit integer from the lower to upper - bits. For instance, for 4-bit quantization we fit 8 elements in an - unsigned 32 bit integer where the 1st element occupies the 4 least - significant bits, the 2nd bits 4-7 etc. - - In order to be able to dequantize the elements of ``w`` we also need to - save :math:`s` and :math:`\beta` which are the returned ``scales`` and - ``biases`` respectively. + The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They + are described in more detail below. Args: w (array): Matrix to be quantized @@ -4234,49 +4219,86 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element of ``w`` in the returned quantized matrix. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. 
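Since the number of returned arrays now depends on ``mode``, a minimal usage sketch of both cases may help; the shapes below are illustrative and not taken from the patch:

import mlx.core as mx

w = mx.random.normal(shape=(128, 256))

# Affine mode (default): three outputs, one scale and bias per group of 64.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
# w_q: (128, 32) uint32, scales/biases: (128, 4) in w's dtype

# mxfp4 mode: two outputs; requires group_size=32 and bits=4.
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")
# w_q: (128, 32) uint32, scales: (128, 8) uint8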
Returns: - tuple: A tuple containing + tuple: A tuple with either two or three elements containing: * w_q (array): The quantized version of ``w`` - * scales (array): The scale to multiply each element with, namely :math:`s` - * biases (array): The biases to add to each element, namely :math:`\beta` + * scales (array): The quantization scales + * biases (array): The quantization biases (returned for ``mode=="affine"``). + + Notes: + The ``affine`` mode quantizes groups of :math:`g` consecutive + elements in a row of ``w``. For each group the quantized + representation of each element :math:`\hat{w_i}` is computed as follows: + + .. math:: + + \begin{aligned} + \alpha &= \max_i w_i \\ + \beta &= \min_i w_i \\ + s &= \frac{\alpha - \beta}{2^b - 1} \\ + \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). + \end{aligned} + + After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits + and is packed in an unsigned 32-bit integer from the lower to upper + bits. For instance, for 4-bit quantization we fit 8 elements in an + unsigned 32 bit integer where the 1st element occupies the 4 least + significant bits, the 2nd bits 4-7 etc. + + To dequantize the elements of ``w``, we also save :math:`s` and + :math:`\beta` which are the returned ``scales`` and + ``biases`` respectively. + + The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements + of ``w``. For ``mxfp4`` the group size must be ``32``. The elements + are quantized to 4-bit precision floating-point values (E2M1) with a + shared 8-bit scale per group. Unlike ``affine`` quantization, + ``mxfp4`` does not have a bias value. More details on the format can + be found in the `specification `_. )pbdoc"); m.def( "dequantize", &mx::dequantize, nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array] = = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( - Dequantize the matrix ``w`` using the provided ``scales`` and - ``biases`` and the ``group_size`` and ``bits`` configuration. - - Formally, given the notation in :func:`quantize`, we compute - :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and - :math:`\beta` as follows - - .. math:: - - w_i = s \hat{w_i} + \beta + Dequantize the matrix ``w`` using quantization parameters. Args: - w (array): Matrix to be quantized - scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + w (array): Matrix to be dequantized + scales (array): The scales to use per ``group_size`` elements of ``w``. + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. group_size (int, optional): The size of the group in ``w`` that shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. Returns: array: The dequantized version of ``w`` + + Notes: + The currently supported quantization modes are ``"affine"`` and ``mxfp4``. 
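As a rough illustration of what the ``mxfp4`` parameters mean, the following sketch mirrors the fallback path added in ``mlx/ops.cpp``; the variable names are mine and the shapes are arbitrary:

import mlx.core as mx

w = mx.random.normal(shape=(8, 64))
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

# Each group of 32 values shares a power-of-two scale stored as a biased
# 8-bit exponent, so the effective per-group scale is 2**(scales - 127).
group_scale = mx.power(mx.array(2.0), scales.astype(mx.float32) - 127)
print(group_scale.shape)  # one scale per group of 32 -> (8, 2)

w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
# Every entry of w_hat is one of the E2M1 values
# {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} times its group's scale.
print(mx.abs(w - w_hat).max())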
+ + For ``affine`` quantization, given the notation in :func:`quantize`, + we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` + and :math:`\beta` as follows + + .. math:: + + w_i = s \hat{w_i} + \beta )pbdoc"); m.def( "gather_qmm", @@ -4284,17 +4306,18 @@ void init_ops(nb::module_& m) { nb::arg(), nb::arg(), "scales"_a, - "biases"_a, + "biases"_a = nb::none(), "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4310,7 +4333,8 @@ void init_ops(nb::module_& m) { x (array): Input array w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` + biases (array, optional): The biases to use per ``group_size`` + elements of ``w``. Default: ``None``. lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. transpose (bool, optional): Defines whether to multiply with the @@ -4320,6 +4344,7 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The quantization mode. Default: ``"affine"``. sorted_indices (bool, optional): May allow a faster implementation if the passed indices are sorted. Default: ``False``. 
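A small usage sketch of the gathered variant with the new arguments, loosely following the added tests; the "expert" count and shapes are invented for illustration:

import mlx.core as mx

x = mx.random.normal(shape=(2, 1, 64))    # two rows of activations
w = mx.random.normal(shape=(4, 32, 64))   # four expert weight matrices
w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

# Pick one expert per row of x; biases stay None for mxfp4.
y = mx.gather_qmm(
    x,
    w_q,
    scales,
    rhs_indices=mx.array([0, 3]),
    transpose=True,
    group_size=32,
    bits=4,
    mode="mxfp4",
)
print(y.shape)  # (2, 1, 32)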
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index af5bace9a..4723bda6a 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -48,6 +48,8 @@ cuda_skip = { "TestQuantized.test_qmm_shapes", "TestQuantized.test_qmm_vjp", "TestQuantized.test_qmv", + "TestQuantized.test_mxfp4_qmv", + "TestQuantized.test_mxfp4_qvm", "TestQuantized.test_qvm", "TestQuantized.test_qvm_splitk", "TestQuantized.test_small_matrix", diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 296f6ee8d..6ded37227 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + nn.quantize(m, group_size=32, mode="mxfp4") + self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding)) + self.assertTrue(isinstance(m.layers[1], nn.ReLU)) + self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + self.assertTrue(isinstance(m.layers[2].scales, mx.array)) + def test_quantize_freeze(self): lin = nn.Linear(512, 512) qlin = lin.to_quantized() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 90a57221f..f22c0cae3 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -27,6 +27,56 @@ class TestQuantized(mlx_tests.MLXTestCase): a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) + def test_mxfp4_quantize_dequantize(self): + lut = mx.array( + [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ] + ) + w = lut[mx.random.randint(0, 16, shape=(128, 512))] + w = w.reshape(-1, 32) + w[:, 0] = 6 + w = (w + 3e-6).astype(mx.bfloat16) + + # Invalid bits / group size + with self.assertRaises(ValueError): + mx.quantize(w, bits=3, group_size=32, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.quantize(w, group_size=64, bits=4, mode="mxfp4") + + w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4") + + with self.assertRaises(ValueError): + mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4") + + w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5)) + + # test quantize/dequantize 0s + a = mx.zeros((256, 512)) + w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4") + self.assertTrue(mx.all(w_hat == 0)) + def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -168,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_mxfp4_qmv(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [256, 512, 67], # M + [64, 128], # N + [0, 1, 3, 8], # B + ) + for M, N, B in tests: + with self.subTest(shape=(B, M, N), group_size=32): + x_shape = (3, 1, N) if B == 0 else (B, 1, N) + w_shape = (M, N) if B == 0 else (B, M, N) + x = mx.random.normal(shape=x_shape, key=k1) + w = mx.random.normal(shape=w_shape, key=k2) + w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") + y_q = mx.quantized_matmul( + x, + w_q, + scales, + transpose=True, + 
group_size=32, + mode="mxfp4", + ) + y_hat = x @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qvm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -233,6 +311,103 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 2e-3) + def test_mxfp4_qvm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [32, 128, 256], # M + [128, 256, 67], # N + [0, 1, 3, 8], # B + ) + # Add a splitk + tests = list(tests) + tests.append((128, 16384, 0)) + + for M, N, B in tests: + with self.subTest(shape=(B, M, N)): + x_shape = (1, N) if B == 0 else (B, 1, N) + w_shape = (N, M) if B == 0 else (B, N, M) + x = mx.random.normal(shape=x_shape, key=k1) + w = mx.random.normal(shape=w_shape, key=k2) + w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") + y_q = mx.quantized_matmul( + x, + w_q, + scales, + transpose=False, + group_size=32, + mode="mxfp4", + ) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 2e-3) + + def test_mode_error_cases(self): + w = mx.random.normal(shape=(256, 256)) + x = mx.random.normal(shape=(1, 256)) + + # Invalid mode + with self.assertRaises(ValueError): + mx.quantize(w, mode="xyz") + + wq, scales, biases = mx.quantize(w, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz") + + with self.assertRaises(ValueError): + mx.quantized_matmul( + x, wq, scales, biases, bits=4, group_size=32, mode="xyz" + ) + + rhs_indices = mx.array(0) + with self.assertRaises(ValueError): + mx.gather_qmm( + x, + wq, + scales, + biases, + rhs_indices=rhs_indices, + bits=4, + group_size=32, + mode="xyz", + ) + + # Only quantize floating point types + with self.assertRaises(ValueError): + mx.quantize(mx.zeros((128, 128), mx.int32)) + + with self.assertRaises(ValueError): + mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4") + + # Must have bias for affine + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, None, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.gather_qmm( + x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32 + ) + + # Must be floating point + x = mx.zeros(shape=(256,), dtype=mx.int32) + scales = mx.zeros(scales.shape, dtype=mx.int32) + biases = mx.zeros(scales.shape, dtype=mx.int32) + with self.assertRaises(ValueError): + mx.dequantize(wq, scales, biases, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32) + + with self.assertRaises(ValueError): + mx.gather_qmm( + x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32 + ) + def test_throw(self): x = mx.random.normal(shape=(10, 512)) w = mx.random.normal(shape=(32, 512)) @@ -360,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_gather_qmm(self): - def quantize(w, transpose=True, group_size=64, bits=4): - qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) - w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"): + if mode == "affine": + 
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + else: + qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + b = None + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode) if transpose: w_hat = w_hat.swapaxes(-1, -2) return w_hat, qw, s, b @@ -379,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=True, group_size=64, bits=4, + mode="affine", ): with self.subTest( M=M, @@ -392,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=transpose, group_size=group_size, bits=bits, + mode=mode, ): x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype) w = mx.random.normal( shape=batch_B + ((N, K) if transpose else (K, N)) ).astype(dtype) - w_hat, qw, s, b = quantize(w, transpose, group_size, bits) + w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode) if lhs_indices is not None: lhs_indices = mx.array(lhs_indices) @@ -415,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=transpose, group_size=group_size, bits=bits, + mode=mode, ) - self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) inputs = ( @@ -460,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase): "batch_B": (4, 1), "rhs_indices": ((2,), (0,), (1,)), }, + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + "group_size": 32, + "mode": "mxfp4", + }, ) for kwargs in inputs: @@ -503,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) def test_gather_qmm_sorted(self): - def quantize(w, transpose=True, group_size=64, bits=4): - qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) - w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"): + if mode == "affine": + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + else: + qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + b = None + + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode) if transpose: w_hat = w_hat.swapaxes(-1, -2) return w_hat, qw, s, b @@ -525,19 +719,23 @@ class TestQuantized(mlx_tests.MLXTestCase): parameters = [ # L, K, D, E, I, transpose - (32, 512, 512, 4, 2, True), - (32, 512, 544, 4, 2, True), - (133, 512, 512, 4, 2, True), - (133, 512, 555, 4, 2, True), - (133, 512, 512, 4, 2, True), - (64, 512, 512, 4, 2, False), - (64, 512, 544, 4, 2, False), - (133, 512, 512, 4, 2, False), - (133, 512, 544, 4, 2, False), - (133, 512, 555, 4, 2, False), - (64, 512, 512, 4, 2, False), + (32, 512, 512, 4, 2, True, "affine"), + (32, 512, 544, 4, 2, True, "mxfp4"), + (133, 512, 512, 4, 2, True, "affine"), + (133, 512, 555, 4, 2, True, "affine"), + (133, 512, 512, 4, 2, True, "affine"), + (64, 512, 512, 4, 2, False, "affine"), + (64, 512, 544, 4, 2, False, "mxfp4"), + (133, 512, 512, 4, 2, False, "affine"), + (133, 512, 544, 4, 2, False, "affine"), + (133, 512, 555, 4, 2, False, "affine"), + (64, 512, 512, 4, 2, False, "affine"), ] - for L, K, D, E, I, transpose in parameters: + for L, K, D, E, I, transpose, mode in parameters: + if mode == "mxfp4": + group_size = 32 + else: + group_size = 64 K, D = (K, D) if transpose else (D, K) ishape = (L, I) xshape = (L, 1, 1, K) @@ -546,14 +744,28 @@ class TestQuantized(mlx_tests.MLXTestCase): indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) x = mx.random.normal(xshape) / K**0.5 w = mx.random.normal(wshape) / K**0.5 - w, *wq = quantize(w, transpose=transpose) + w, 
*wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose) y1 = mx.gather_mm(x, w, rhs_indices=indices) - y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices) + y2 = mx.gather_qmm( + x, + *wq, + group_size=group_size, + mode=mode, + transpose=transpose, + rhs_indices=indices + ) xs, idx, inv_order = gather_sort(x, indices) y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + y4 = mx.gather_qmm( - xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True + xs, + *wq, + group_size=group_size, + mode=mode, + rhs_indices=idx, + transpose=transpose, + sorted_indices=True ) y3 = scatter_unsort(y3, inv_order, indices.shape) y4 = scatter_unsort(y4, inv_order, indices.shape) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 17207efd4..878c7101b 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2996,7 +2996,10 @@ TEST_CASE("test quantize dequantize") { for (int i = 2; i <= 8; i *= 2) { int el_per_int = 32 / i; - auto [x_q, scales, biases] = quantize(x, 128, i); + auto res = quantize(x, 128, i); + auto x_q = res[0]; + auto scales = res[1]; + auto biases = res[2]; CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int}); CHECK_EQ(scales.shape(), Shape{128, 4}); CHECK_EQ(biases.shape(), Shape{128, 4});
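For comparison, the same shape bookkeeping can be exercised from Python, where the default affine mode still returns three arrays; this is a sketch mirroring the C++ test above:

import mlx.core as mx

x = mx.random.normal(shape=(128, 512))
for bits in (2, 4, 8):
    el_per_int = 32 // bits
    w_q, scales, biases = mx.quantize(x, group_size=128, bits=bits)
    assert w_q.shape == (128, 512 // el_per_int)
    assert scales.shape == (128, 4)
    assert biases.shape == (128, 4)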