From e78a6518faa6bdb42eb18039522fab6d58fb9b02 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 16 May 2024 15:24:14 -0700 Subject: [PATCH] Block sparse qmm (#1124) --- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/quantized.cpp | 118 ++- mlx/backend/metal/kernels/quantized.metal | 868 ++++++++++++++++++++-- mlx/backend/metal/quantized.cpp | 291 +++++++- mlx/backend/no_metal/primitives.cpp | 1 + mlx/ops.cpp | 263 ++++--- mlx/ops.h | 13 + mlx/primitives.cpp | 80 +- mlx/primitives.h | 28 + python/mlx/nn/layers/embedding.py | 5 + python/mlx/nn/layers/linear.py | 5 + python/mlx/nn/layers/quantized.py | 24 +- python/src/ops.cpp | 48 +- python/tests/test_quantized.py | 142 ++++ 15 files changed, 1724 insertions(+), 164 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 9bf1868c2..5a2500c64 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -33,6 +33,7 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(BlockMaskedMM) DEFAULT(BlockSparseMM) +DEFAULT(BlockSparseQMM) DEFAULT(Broadcast) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index ec5289d6a..4ebb27af2 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -44,6 +44,7 @@ DEFAULT(AsStrided) DEFAULT(Broadcast) DEFAULT(BlockMaskedMM) DEFAULT(BlockSparseMM) +DEFAULT(BlockSparseQMM) DEFAULT_MULTI(DivMod) DEFAULT(Ceil) DEFAULT(Concatenate) diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp index be55cca29..4dfb1780e 100644 --- a/mlx/backend/common/quantized.cpp +++ b/mlx/backend/common/quantized.cpp @@ -192,7 +192,7 @@ void _qmm_dispatch_typed( } void _qmm_dispatch( - array out, + array& out, const array& x, const array& w, const array& scales, @@ -253,6 +253,81 @@ void _qmm_dispatch( } } +void _bs_qmm_dispatch( + array& out, + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + int bits, + int group_size, + bool transposed_w) { + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + + int w_els = w.shape(-1) * w.shape(-2); + int g_els = scales.shape(-1) * scales.shape(-2); + + const uint32_t* lhs_indices_data = lhs_indices.data(); + const uint32_t* rhs_indices_data = rhs_indices.data(); + + for (int i = 0; i < lhs_indices.size(); i++) { + int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)]; + int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)]; + + switch (x.dtype()) { + case float32: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(x_idx * M * K, x), + w.data() + elem_to_loc(w_idx * w_els, w), + scales.data() + elem_to_loc(w_idx * g_els, scales), + biases.data() + elem_to_loc(w_idx * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + case float16: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(x_idx * M * K, x), + w.data() + elem_to_loc(w_idx * w_els, w), + scales.data() + elem_to_loc(w_idx * g_els, scales), + biases.data() + elem_to_loc(w_idx * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + case bfloat16: + _qmm_dispatch_typed( + out.data() + i * M * N, + x.data() + elem_to_loc(x_idx * M * K, x), + w.data() + elem_to_loc(w_idx * w_els, w), + scales.data() + elem_to_loc(w_idx * 
g_els, scales), + biases.data() + elem_to_loc(w_idx * g_els, biases), + M, + N, + K, + bits, + group_size, + transposed_w); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } + } +} + } // namespace void QuantizedMatmul::eval(const std::vector& inputs, array& out) { @@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector& inputs, array& out) { _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); } +void BlockSparseQMM::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 6); + + auto& x_pre = inputs[0]; + auto& w_pre = inputs[1]; + auto& scales_pre = inputs[2]; + auto& biases_pre = inputs[3]; + auto& lhs_indices = inputs[4]; + auto& rhs_indices = inputs[5]; + + auto ensure_row_contiguous_last_dims = [](const array& arr) { + auto stride_0 = arr.strides()[arr.ndim() - 2]; + auto stride_1 = arr.strides()[arr.ndim() - 1]; + if (stride_0 == arr.shape(-1) && stride_1 == 1) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + return arr_copy; + } + }; + + auto x = ensure_row_contiguous_last_dims(x_pre); + auto w = ensure_row_contiguous_last_dims(w_pre); + auto scales = ensure_row_contiguous_last_dims(scales_pre); + auto biases = ensure_row_contiguous_last_dims(biases_pre); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + _bs_qmm_dispatch( + out, + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + group_size_, + bits_, + transpose_); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 58f165faa..61fe3c303 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -378,14 +378,14 @@ struct QuantizedBlockLoader { }; template -[[kernel]] void qmv_fast( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], +METAL_FUNC void qmv_fast_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -404,13 +404,13 @@ template // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; w += out_row * in_vec_size_w + simd_lid * packs_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.z * in_vec_size + simd_lid * values_per_thread; - y += tid.z * out_vec_size + out_row; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; for (int k = 0; k < in_vec_size; k += block_size) { U sum = load_vector(x, x_thread); @@ -440,15 +440,15 @@ template } } -template -[[kernel]] void qmv( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases 
[[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], +template +METAL_FUNC void qmv_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -468,7 +468,7 @@ template // Adjust positions const int in_vec_size_w = in_vec_size / pack_factor; const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); @@ -482,8 +482,8 @@ template w += out_row * in_vec_size_w + simd_lid * packs_per_thread; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.z * in_vec_size + simd_lid * values_per_thread; - y += tid.z * out_vec_size + out_row; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { @@ -537,8 +537,8 @@ template w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.z * in_vec_size + simd_lid * values_per_thread; - y += tid.z * out_vec_size + used_out_row; + x += tid.y * in_vec_size + simd_lid * values_per_thread; + y += tid.y * out_vec_size + used_out_row; int k = 0; for (; k < in_vec_size - block_size; k += block_size) { @@ -590,14 +590,14 @@ template } template -[[kernel]] void qvm( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], +METAL_FUNC void qvm_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -616,12 +616,12 @@ template // Adjust positions const int out_vec_size_w = out_vec_size / pack_factor; const int out_vec_size_g = out_vec_size / group_size; - int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor; + int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor; w += out_col / pack_factor; scales += out_col / group_size; biases += out_col / group_size; - x += tid.z * in_vec_size; - y += tid.z * out_vec_size + out_col; + x += tid.y * in_vec_size; + y += tid.y * out_vec_size + out_col; if (out_col >= out_vec_size) { return; @@ -675,15 +675,17 @@ template < const int group_size, const int bits, const bool aligned_N> -[[kernel]] void qmm_t( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y 
[[buffer(4)]], - const constant int& M [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], +METAL_FUNC void qmm_t_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -713,9 +715,6 @@ template < group_size, bits>; - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - // Set the block const int K_w = K / pack_factor; const int K_g = K / group_size; @@ -797,15 +796,17 @@ template < const int BN, const int group_size, const int bits> -[[kernel]] void qmm_n( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& M [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& K [[buffer(7)]], +METAL_FUNC void qmm_n_impl( + const device T* x, + const device uint32_t* w, + const device T* scales, + const device T* biases, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& M, + const constant int& N, + const constant int& K, uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -836,9 +837,6 @@ template < group_size, bits>; - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; @@ -923,6 +921,518 @@ template < } } +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant size_t* lhs_strides, + const constant size_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant size_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant size_t* w_strides, + const constant size_t* s_strides, + const constant size_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + 
device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits, + const bool aligned_N> +[[kernel]] void qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits> +[[kernel]] void qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void bs_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x 
[[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void bs_qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& batch_ndims [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant size_t* lhs_strides [[buffer(11)]], + 
const constant size_t* rhs_strides [[buffer(12)]], + const constant int& x_batch_ndims [[buffer(13)]], + const constant int* x_shape [[buffer(14)]], + const constant size_t* x_strides [[buffer(15)]], + const constant int& w_batch_ndims [[buffer(16)]], + const constant int* w_shape [[buffer(17)]], + const constant size_t* w_strides [[buffer(18)]], + const constant size_t* s_strides [[buffer(19)]], + const constant size_t* b_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qvm_impl( + x, + w, + scales, + biases, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits, + const bool aligned_N> +[[kernel]] void bs_qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides [[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int BM, + const int BK, + const int BN, + const int group_size, + const int bits> +[[kernel]] void bs_qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& K [[buffer(9)]], + const constant int& batch_ndims [[buffer(10)]], + const constant int* batch_shape [[buffer(11)]], + const constant size_t* 
lhs_strides [[buffer(12)]], + const constant size_t* rhs_strides [[buffer(13)]], + const constant int& x_batch_ndims [[buffer(14)]], + const constant int* x_shape [[buffer(15)]], + const constant size_t* x_strides [[buffer(16)]], + const constant int& w_batch_ndims [[buffer(17)]], + const constant int* w_shape [[buffer(18)]], + const constant size_t* w_strides [[buffer(19)]], + const constant size_t* s_strides [[buffer(20)]], + const constant size_t* b_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_impl( + x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); +} + #define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \ template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \ "_fast")]] [[kernel]] void \ @@ -1089,3 +1599,241 @@ instantiate_qmm_n_types( 64, 8) instantiate_qmm_n_types( 32, 2) instantiate_qmm_n_types( 32, 4) instantiate_qmm_n_types( 32, 8) // clang-format on + +#define instantiate_bs_qmv_fast( \ + name, itype, group_size, bits, packs_per_thread) \ + template [[host_name("bs_qmv_" #name "_gs_" #group_size "_b_" #bits \ + "_fast")]] [[kernel]] void \ + bs_qmv_fast( \ + const device uint32_t* w [[buffer(0)]], \ + const device itype* scales [[buffer(1)]], \ + const device itype* biases [[buffer(2)]], \ + const device itype* x [[buffer(3)]], \ + const device uint32_t* lhs_indices [[buffer(4)]], \ + const device uint32_t* rhs_indices [[buffer(5)]], \ + device itype* y [[buffer(6)]], \ + const constant int& in_vec_size [[buffer(7)]], \ + const constant int& out_vec_size [[buffer(8)]], \ + const constant int& batch_ndims [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* lhs_strides [[buffer(11)]], \ + const constant size_t* rhs_strides [[buffer(12)]], \ + const constant int& x_batch_ndims [[buffer(13)]], \ + const constant int* x_shape [[buffer(14)]], \ + const constant size_t* x_strides [[buffer(15)]], \ + const constant int& w_batch_ndims [[buffer(16)]], \ + const constant int* w_shape [[buffer(17)]], \ + const constant size_t* w_strides [[buffer(18)]], \ + const constant size_t* s_strides [[buffer(19)]], \ + const constant size_t* b_strides [[buffer(20)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_bs_qmv_fast_types(group_size, bits, packs_per_thread) \ + instantiate_bs_qmv_fast(float32, float, group_size, bits, packs_per_thread) \ + instantiate_bs_qmv_fast(float16, half, group_size, bits, packs_per_thread) \ + instantiate_bs_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on + + // clang-format off +instantiate_bs_qmv_fast_types(128, 2, 1) +instantiate_bs_qmv_fast_types(128, 4, 2) +instantiate_bs_qmv_fast_types(128, 8, 2) +instantiate_bs_qmv_fast_types( 64, 2, 1) 
+instantiate_bs_qmv_fast_types( 64, 4, 2) +instantiate_bs_qmv_fast_types( 64, 8, 2) +instantiate_bs_qmv_fast_types( 32, 2, 1) +instantiate_bs_qmv_fast_types( 32, 4, 2) +instantiate_bs_qmv_fast_types( 32, 8, 2) // clang-format on + +#define instantiate_bs_qmv(name, itype, group_size, bits) \ + template [[host_name("bs_qmv_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + bs_qmv( \ + const device uint32_t* w [[buffer(0)]], \ + const device itype* scales [[buffer(1)]], \ + const device itype* biases [[buffer(2)]], \ + const device itype* x [[buffer(3)]], \ + const device uint32_t* lhs_indices [[buffer(4)]], \ + const device uint32_t* rhs_indices [[buffer(5)]], \ + device itype* y [[buffer(6)]], \ + const constant int& in_vec_size [[buffer(7)]], \ + const constant int& out_vec_size [[buffer(8)]], \ + const constant int& batch_ndims [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* lhs_strides [[buffer(11)]], \ + const constant size_t* rhs_strides [[buffer(12)]], \ + const constant int& x_batch_ndims [[buffer(13)]], \ + const constant int* x_shape [[buffer(14)]], \ + const constant size_t* x_strides [[buffer(15)]], \ + const constant int& w_batch_ndims [[buffer(16)]], \ + const constant int* w_shape [[buffer(17)]], \ + const constant size_t* w_strides [[buffer(18)]], \ + const constant size_t* s_strides [[buffer(19)]], \ + const constant size_t* b_strides [[buffer(20)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_bs_qmv_types(group_size, bits) \ + instantiate_bs_qmv(float32, float, group_size, bits) \ + instantiate_bs_qmv(float16, half, group_size, bits) \ + instantiate_bs_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on + + // clang-format off +instantiate_bs_qmv_types(128, 2) +instantiate_bs_qmv_types(128, 4) +instantiate_bs_qmv_types(128, 8) +instantiate_bs_qmv_types( 64, 2) +instantiate_bs_qmv_types( 64, 4) +instantiate_bs_qmv_types( 64, 8) +instantiate_bs_qmv_types( 32, 2) +instantiate_bs_qmv_types( 32, 4) +instantiate_bs_qmv_types( 32, 8) // clang-format on + +#define instantiate_bs_qvm(name, itype, group_size, bits) \ + template [[host_name("bs_qvm_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + bs_qvm( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + const device uint32_t* lhs_indices [[buffer(4)]], \ + const device uint32_t* rhs_indices [[buffer(5)]], \ + device itype* y [[buffer(6)]], \ + const constant int& in_vec_size [[buffer(7)]], \ + const constant int& out_vec_size [[buffer(8)]], \ + const constant int& batch_ndims [[buffer(9)]], \ + const constant int* batch_shape [[buffer(10)]], \ + const constant size_t* lhs_strides [[buffer(11)]], \ + const constant size_t* rhs_strides [[buffer(12)]], \ + const constant int& x_batch_ndims [[buffer(13)]], \ + const constant int* x_shape [[buffer(14)]], \ + const constant size_t* x_strides [[buffer(15)]], \ + const constant int& w_batch_ndims [[buffer(16)]], \ + const constant int* w_shape [[buffer(17)]], \ + const constant size_t* w_strides [[buffer(18)]], \ + const constant size_t* s_strides [[buffer(19)]], \ + const constant size_t* b_strides [[buffer(20)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint 
simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_bs_qvm_types(group_size, bits) \ + instantiate_bs_qvm(float32, float, group_size, bits) \ + instantiate_bs_qvm(float16, half, group_size, bits) \ + instantiate_bs_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on + + // clang-format off +instantiate_bs_qvm_types(128, 2) +instantiate_bs_qvm_types(128, 4) +instantiate_bs_qvm_types(128, 8) +instantiate_bs_qvm_types( 64, 2) +instantiate_bs_qvm_types( 64, 4) +instantiate_bs_qvm_types( 64, 8) +instantiate_bs_qvm_types( 32, 2) +instantiate_bs_qvm_types( 32, 4) +instantiate_bs_qvm_types( 32, 8) // clang-format on + +#define instantiate_bs_qmm_t(name, itype, group_size, bits, aligned_N) \ + template [[host_name("bs_qmm_t_" #name "_gs_" #group_size "_b_" #bits \ + "_alN_" #aligned_N)]] [[kernel]] void \ + bs_qmm_t( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + const device uint32_t* lhs_indices [[buffer(4)]], \ + const device uint32_t* rhs_indices [[buffer(5)]], \ + device itype* y [[buffer(6)]], \ + const constant int& M [[buffer(7)]], \ + const constant int& N [[buffer(8)]], \ + const constant int& K [[buffer(9)]], \ + const constant int& batch_ndims [[buffer(10)]], \ + const constant int* batch_shape [[buffer(11)]], \ + const constant size_t* lhs_strides [[buffer(12)]], \ + const constant size_t* rhs_strides [[buffer(13)]], \ + const constant int& x_batch_ndims [[buffer(14)]], \ + const constant int* x_shape [[buffer(15)]], \ + const constant size_t* x_strides [[buffer(16)]], \ + const constant int& w_batch_ndims [[buffer(17)]], \ + const constant int* w_shape [[buffer(18)]], \ + const constant size_t* w_strides [[buffer(19)]], \ + const constant size_t* s_strides [[buffer(20)]], \ + const constant size_t* b_strides [[buffer(21)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_bs_qmm_t_types(group_size, bits) \ + instantiate_bs_qmm_t(float32, float, group_size, bits, false) \ + instantiate_bs_qmm_t(float16, half, group_size, bits, false) \ + instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \ + instantiate_bs_qmm_t(float32, float, group_size, bits, true) \ + instantiate_bs_qmm_t(float16, half, group_size, bits, true) \ + instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on + + // clang-format off +instantiate_bs_qmm_t_types(128, 2) +instantiate_bs_qmm_t_types(128, 4) +instantiate_bs_qmm_t_types(128, 8) +instantiate_bs_qmm_t_types( 64, 2) +instantiate_bs_qmm_t_types( 64, 4) +instantiate_bs_qmm_t_types( 64, 8) +instantiate_bs_qmm_t_types( 32, 2) +instantiate_bs_qmm_t_types( 32, 4) +instantiate_bs_qmm_t_types( 32, 8) // clang-format on + +#define instantiate_bs_qmm_n(name, itype, group_size, bits) \ + template [[host_name("bs_qmm_n_" #name "_gs_" #group_size \ + "_b_" #bits)]] [[kernel]] void \ + bs_qmm_n( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + const device uint32_t* lhs_indices [[buffer(4)]], \ + const device uint32_t* rhs_indices [[buffer(5)]], \ + device itype* y [[buffer(6)]], \ + const constant int& M [[buffer(7)]], \ + const constant 
int& N [[buffer(8)]], \ + const constant int& K [[buffer(9)]], \ + const constant int& batch_ndims [[buffer(10)]], \ + const constant int* batch_shape [[buffer(11)]], \ + const constant size_t* lhs_strides [[buffer(12)]], \ + const constant size_t* rhs_strides [[buffer(13)]], \ + const constant int& x_batch_ndims [[buffer(14)]], \ + const constant int* x_shape [[buffer(15)]], \ + const constant size_t* x_strides [[buffer(16)]], \ + const constant int& w_batch_ndims [[buffer(17)]], \ + const constant int* w_shape [[buffer(18)]], \ + const constant size_t* w_strides [[buffer(19)]], \ + const constant size_t* s_strides [[buffer(20)]], \ + const constant size_t* b_strides [[buffer(21)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +// clang-format off +#define instantiate_bs_qmm_n_types(group_size, bits) \ + instantiate_bs_qmm_n(float32, float, group_size, bits) \ + instantiate_bs_qmm_n(float16, half, group_size, bits) \ + instantiate_bs_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on + + // clang-format off +instantiate_bs_qmm_n_types(128, 2) +instantiate_bs_qmm_n_types(128, 4) +instantiate_bs_qmm_n_types(128, 8) +instantiate_bs_qmm_n_types( 64, 2) +instantiate_bs_qmm_n_types( 64, 4) +instantiate_bs_qmm_n_types( 64, 8) +instantiate_bs_qmm_n_types( 32, 2) +instantiate_bs_qmm_n_types( 32, 4) +instantiate_bs_qmm_n_types( 32, 8) // clang-format on diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 4f48f9ce8..980509f47 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int bo = 8; int bd = 32; MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(1, O / bo, B); + MTL::Size grid_dims = MTL::Size(O / bo, B, 1); compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(scales, 1); @@ -82,7 +82,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int bo = 8; int bd = 32; MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); + MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1); compute_encoder.set_input_array(w, 0); compute_encoder.set_input_array(scales, 1); @@ -140,7 +140,7 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int bo = 8; int bd = 32; MTL::Size group_dims = MTL::Size(bd, bo, 1); - MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B); + MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1); compute_encoder.set_input_array(x, 0); compute_encoder.set_input_array(w, 1); @@ -196,4 +196,289 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); } +void BlockSparseQMM::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 6); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& x_pre = inputs[0]; + auto& w_pre = inputs[1]; + auto& scales_pre = inputs[2]; + auto& biases_pre = inputs[3]; + auto& lhs_indices = inputs[4]; + auto& rhs_indices = inputs[5]; + + // TODO: collapse batch dims + auto& batch_shape = lhs_indices.shape(); + int batch_ndims = batch_shape.size(); + auto& lhs_strides = lhs_indices.strides(); + auto& rhs_strides = rhs_indices.strides(); 
+ + // Ensure that the last two dims are row contiguous. + // TODO: Check if we really need this for x as well... + std::vector copies; + auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { + auto stride_0 = arr.strides()[arr.ndim() - 2]; + auto stride_1 = arr.strides()[arr.ndim() - 1]; + if (stride_0 == arr.shape(-1) && stride_1 == 1) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + return arr_copy; + } + }; + auto x = ensure_row_contiguous_last_dims(x_pre); + auto w = ensure_row_contiguous_last_dims(w_pre); + auto scales = ensure_row_contiguous_last_dims(scales_pre); + auto biases = ensure_row_contiguous_last_dims(biases_pre); + + int x_batch_ndims = x.ndim() - 2; + auto& x_shape = x.shape(); + auto& x_strides = x.strides(); + int w_batch_ndims = w.ndim() - 2; + auto& w_shape = w.shape(); + auto& w_strides = w.strides(); + auto& s_strides = scales.strides(); + auto& b_strides = biases.strides(); + + int D = x.shape(-1); + int B = x.shape(-2); + int O = out.shape(-1); + int N = out.size() / B / O; + if (transpose_) { + // Route to the fast bs_qmv kernel that has no bounds checking + if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { + std::ostringstream kname; + kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_ << "_fast"; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int bo = 8; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, 2, 1); + MTL::Size grid_dims = MTL::Size(O / bo, B, N); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder->setBytes(&D, sizeof(int), 7); + compute_encoder->setBytes(&O, sizeof(int), 8); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, lhs_strides, 11); + set_vector_bytes(compute_encoder, rhs_strides, 12); + + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); + set_vector_bytes(compute_encoder, x_shape, 14); + set_vector_bytes(compute_encoder, x_strides, 15); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); + set_vector_bytes(compute_encoder, w_shape, 17); + set_vector_bytes(compute_encoder, w_strides, 18); + set_vector_bytes(compute_encoder, s_strides, 19); + set_vector_bytes(compute_encoder, b_strides, 20); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + } + + else if (B < 6) { + std::ostringstream kname; + kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int bo = 8; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, 2, 1); + MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + 
compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder->setBytes(&D, sizeof(int), 7); + compute_encoder->setBytes(&O, sizeof(int), 8); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, lhs_strides, 11); + set_vector_bytes(compute_encoder, rhs_strides, 12); + + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); + set_vector_bytes(compute_encoder, x_shape, 14); + set_vector_bytes(compute_encoder, x_strides, 15); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); + set_vector_bytes(compute_encoder, w_shape, 17); + set_vector_bytes(compute_encoder, w_strides, 18); + set_vector_bytes(compute_encoder, s_strides, 19); + set_vector_bytes(compute_encoder, b_strides, 20); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + } + + // Route to the bs_qmm_t + else { + std::ostringstream kname; + kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_ + << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0); + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int wn = 2; + int wm = 2; + int bm = 32; + int bn = 32; + int bk = 32; + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder->setBytes(&B, sizeof(int), 7); + compute_encoder->setBytes(&O, sizeof(int), 8); + compute_encoder->setBytes(&D, sizeof(int), 9); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), 10); + set_vector_bytes(compute_encoder, batch_shape, 11); + set_vector_bytes(compute_encoder, lhs_strides, 12); + set_vector_bytes(compute_encoder, rhs_strides, 13); + + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14); + set_vector_bytes(compute_encoder, x_shape, 15); + set_vector_bytes(compute_encoder, x_strides, 16); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17); + set_vector_bytes(compute_encoder, w_shape, 18); + set_vector_bytes(compute_encoder, w_strides, 19); + set_vector_bytes(compute_encoder, s_strides, 20); + set_vector_bytes(compute_encoder, b_strides, 21); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + } + } else { + // Route to the bs_qvm kernel + if (B < 4) { + std::ostringstream kname; + kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int bo = 8; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, bo, 1); + MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + 
compute_encoder.set_output_array(out, 6); + compute_encoder->setBytes(&D, sizeof(int), 7); + compute_encoder->setBytes(&O, sizeof(int), 8); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), 9); + set_vector_bytes(compute_encoder, batch_shape, 10); + set_vector_bytes(compute_encoder, lhs_strides, 11); + set_vector_bytes(compute_encoder, rhs_strides, 12); + + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13); + set_vector_bytes(compute_encoder, x_shape, 14); + set_vector_bytes(compute_encoder, x_strides, 15); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16); + set_vector_bytes(compute_encoder, w_shape, 17); + set_vector_bytes(compute_encoder, w_strides, 18); + set_vector_bytes(compute_encoder, s_strides, 19); + set_vector_bytes(compute_encoder, b_strides, 20); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + } + + // Route to bs_qmm_n + else { + std::ostringstream kname; + kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_ + << "_b_" << bits_; + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int wn = 2; + int wm = 2; + int bm = 32; + int bn = 32; + int bk = 32; + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); + + if ((O % bn) != 0) { + std::ostringstream msg; + msg << "[quantized_matmul] The output size should be divisible by " + << bn << " but received " << O << "."; + throw std::runtime_error(msg.str()); + } + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder->setBytes(&B, sizeof(int), 7); + compute_encoder->setBytes(&O, sizeof(int), 8); + compute_encoder->setBytes(&D, sizeof(int), 9); + + compute_encoder->setBytes(&batch_ndims, sizeof(int), 10); + set_vector_bytes(compute_encoder, batch_shape, 11); + set_vector_bytes(compute_encoder, lhs_strides, 12); + set_vector_bytes(compute_encoder, rhs_strides, 13); + + compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14); + set_vector_bytes(compute_encoder, x_shape, 15); + set_vector_bytes(compute_encoder, x_strides, 16); + compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17); + set_vector_bytes(compute_encoder, w_shape, 18); + set_vector_bytes(compute_encoder, w_strides, 19); + set_vector_bytes(compute_encoder, s_strides, 20); + set_vector_bytes(compute_encoder, b_strides, 21); + + compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + } + } +} + } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 63114d386..9934336cb 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -35,6 +35,7 @@ NO_GPU(AsStrided) NO_GPU(BitwiseBinary) NO_GPU(BlockMaskedMM) NO_GPU(BlockSparseMM) +NO_GPU(BlockSparseQMM) NO_GPU(Broadcast) NO_GPU(Ceil) NO_GPU_MULTI(Compiled) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 83be15384..8f2062a6d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -50,6 +50,83 @@ Dtype at_least_float(const Dtype& d) { return issubdtype(d, inexact) ? 
d : promote_types(d, float32); } +array indices_or_default( + std::optional indices, + const array& x, + StreamOrDevice s) { + if (indices.has_value()) { + return indices.value(); + } + + std::vector shape(x.shape().begin(), x.shape().end() - 2); + int total = + std::reduce(shape.begin(), shape.end(), 1, std::multiplies()); + return reshape(arange(total, uint32, s), shape, s); +} + +std::pair extract_quantized_matmul_dims( + std::string_view tag, + const array& x, + const array& w, + const array& scales, + const array& biases, + bool transpose, + int group_size, + int bits) { + if (w.dtype() != uint32) { + std::ostringstream msg; + msg << "[" << tag << "] The weight matrix should be uint32 " + << "but received" << w.dtype(); + throw std::invalid_argument(msg.str()); + } + + if (scales.shape() != biases.shape()) { + std::ostringstream msg; + msg << "[" << tag << "] Scales and biases should have the same shape. " + << "Received scales with shape " << scales.shape() + << " and biases with " << biases.shape(); + throw std::invalid_argument(msg.str()); + } + + if (!std::equal( + w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) { + std::ostringstream msg; + msg << "[" << tag + << "] Weight, scales and biases should have the same batch shape. " + << "Received weight with shape " << w.shape() << ", scales with " + << scales.shape() << " and biases with " << biases.shape(); + throw std::invalid_argument(msg.str()); + } + + if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[" << tag << "] The shapes of the weight and scales are " + << "incompatible based on bits and group_size. w.shape() == " + << w.shape() << " and scales.shape() == " << scales.shape() + << " with group_size=" << group_size << " and bits=" << bits; + throw std::invalid_argument(msg.str()); + } + + int x_inner_dims = x.shape(-1); + + // Calculate the expanded w's dims + int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2); + int w_outer_dims = (transpose) ? 
w.shape(-2) : w.shape(-1) * 32 / bits; + + if (w_inner_dims != x_inner_dims) { + std::ostringstream msg; + msg << "[" << tag << "] Last dimension of first input with " + << "shape (..., " << x_inner_dims << ") does not match " + << "the expanded quantized matrix (" << w_inner_dims << ", " + << w_outer_dims << ") computed from shape " << w.shape() + << " with group_size=" << group_size << ", bits=" << bits + << " and transpose=" << std::boolalpha << transpose; + throw std::invalid_argument(msg.str()); + } + + return {w_inner_dims, w_outer_dims}; +} + } // namespace array arange( @@ -3203,7 +3280,7 @@ array conv_general( } array quantized_matmul( - const array& in_x, + const array& x, const array& w, const array& scales, const array& biases, @@ -3211,13 +3288,10 @@ array quantized_matmul( int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { - array x = in_x; - if (w.dtype() != uint32) { - std::ostringstream msg; - msg << "[quantized_matmul] The weight matrix should be uint32 " - << "but received" << w.dtype(); - throw std::invalid_argument(msg.str()); - } + // Check and extract the quantized matrix shape against x + auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( + "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); + if (w.ndim() != 2) { std::ostringstream msg; msg << "[quantized_matmul] Batched quantized matmul is not supported for now " @@ -3225,42 +3299,6 @@ array quantized_matmul( throw std::invalid_argument(msg.str()); } - // Keep x's batch dimensions to reshape it back after the matmul - auto original_shape = x.shape(); - int x_inner_dims = original_shape.back(); - - if (scales.ndim() != 2 || scales.shape() != biases.shape()) { - std::ostringstream msg; - msg << "[quantized_matmul] Scales and biases should have the same 2D shape. " - << "Received scales with shape " << scales.shape() - << " and biases with " << biases.shape(); - throw std::invalid_argument(msg.str()); - } - - if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) { - std::ostringstream msg; - msg << "[quantized_matmul] The shapes of the weight and scales are " - << "incompatible based on bits and group_size. w.shape() == " - << w.shape() << " and scales.shape() == " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits; - throw std::invalid_argument(msg.str()); - } - - // Calculate the expanded w's dims - int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0); - int w_outer_dims = (transpose) ? 
w.shape(0) : w.shape(1) * 32 / bits; - - if (w_inner_dims != x_inner_dims) { - std::ostringstream msg; - msg << "[quantized_matmul] Last dimension of first input with " - << "shape (..., " << x_inner_dims << ") does not match " - << "the expanded quantized matrix (" << w_inner_dims << ", " - << w_outer_dims << ") computed from shape " << w.shape() - << " with group_size=" << group_size << ", bits=" << bits - << " and transpose=" << std::boolalpha << transpose; - throw std::invalid_argument(msg.str()); - } - auto dtype = result_type(x, scales, biases); if (!issubdtype(dtype, floating)) { std::ostringstream msg; @@ -3270,10 +3308,11 @@ array quantized_matmul( << " and biases.dtype() == " << biases.dtype(); throw std::invalid_argument(msg.str()); } - std::vector inputs; - original_shape.back() = w_outer_dims; + + auto out_shape = x.shape(); + out_shape.back() = w_outer_dims; return array( - std::move(original_shape), + std::move(out_shape), dtype, std::make_shared( to_stream(s), group_size, bits, transpose), @@ -3302,11 +3341,14 @@ std::tuple quantize( throw std::invalid_argument(msg.str()); } - if (w.ndim() != 2) { - throw std::invalid_argument("[quantize] Only matrices supported for now"); + if (w.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); } - if ((w.shape(1) % group_size) != 0) { + if ((w.shape(-1) % group_size) != 0) { std::ostringstream msg; msg << "[quantize] The last dimension of the matrix needs to be divisible by " << "the quantization group size " << group_size @@ -3327,7 +3369,7 @@ std::tuple quantize( // at least we bail out early which will result in a nice readable error. // // Hopefully nobody is quantizing matrices that small anyway. - if (w.shape(1) < 32 * el_per_int) { + if (w.shape(-1) < 32 * el_per_int) { std::ostringstream msg; msg << "[quantize] The feature dimension (2nd dimension of the matrix) is " << "too small for quantization. We support >=512 for 2 bits, " @@ -3336,9 +3378,12 @@ std::tuple quantize( throw std::invalid_argument(msg.str()); } + // Prepare the shape for the outputs. 
+ auto wshape = w.shape(); + wshape.back() = -1; + // Compute scales and biases - array packed_w = - reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s); + array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); @@ -3357,12 +3402,14 @@ std::tuple quantize( zero, n_bins), uint32); - packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s); + packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); packed_w = sum( multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); return std::make_tuple( - packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s)); + reshape(packed_w, wshape, s), + reshape(scales, wshape, s), + reshape(biases, wshape, s)); } array dequantize( @@ -3382,11 +3429,21 @@ array dequantize( msg << "[dequantize] Invalid value for group_size: " << group_size; throw std::invalid_argument(msg.str()); } - if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) { - throw std::invalid_argument("[dequantize] Only matrices supported for now"); + if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { + std::ostringstream msg; + msg << "[quantize] The matrix to be quantized must have at least 2 dimension " + << "but it has only " << w.ndim() << "."; + throw std::invalid_argument(msg.str()); } - if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) { + auto wshape = w.shape(); + auto sshape = scales.shape(); + auto bshape = biases.shape(); + wshape.back() = -1; + sshape.back() = -1; + bshape.back() = -1; + + if (wshape != sshape || wshape != bshape) { throw std::invalid_argument( "[dequantize] Shape of scales and biases does not match the matrix"); } @@ -3399,7 +3456,7 @@ array dequantize( // Compute some constants for the dequantization int el_per_int = 32 / bits; - if (w.shape(1) * el_per_int != scales.shape(1) * group_size) { + if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) { std::ostringstream msg; msg << "[dequantize] Shape of scales and biases does not match the matrix " << "given the quantization parameters. 
Provided matrix of shape " @@ -3411,25 +3468,79 @@ array dequantize( // Extract the pieces from the passed quantized matrix std::vector parts; for (int start = 0; start < 32; start += bits) { - // TODO: Implement bitwise operators for integral types int shift_left = 32 - (start + bits); int shift_right = shift_left + start; - array p = multiply(w, array(1 << shift_left, uint32), s); - p = floor_divide(p, array(1 << shift_right, uint32), s); - p = expand_dims(p, -1, s); - parts.push_back(p); + + parts.push_back(expand_dims( + right_shift( + left_shift(w, array(32 - (start + bits), uint32), s), + array(32 - bits, uint32), + s), + -1, + s)); } array w_full = concatenate(parts, -1, s); // Dequantize - w_full = reshape(w_full, {w.shape(0), -1, group_size}, s); + wshape.push_back(group_size); + w_full = reshape(w_full, wshape, s); w_full = multiply(w_full, expand_dims(scales, -1, s), s); w_full = add(w_full, expand_dims(biases, -1, s), s); - w_full = reshape(w_full, {w.shape(0), -1}, s); + w_full = reshape(w_full, sshape, s); return w_full; } +array block_sparse_qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, + std::optional lhs_indices_ /* = std::nullopt */, + std::optional rhs_indices_ /* = std::nullopt */, + bool transpose /* = true */, + int group_size /* = 64 */, + int bits /* = 4 */, + StreamOrDevice s /* = {} */) { + if (!lhs_indices_ && !rhs_indices_) { + return quantized_matmul( + x, w, scales, biases, transpose, group_size, bits, s); + } + + auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( + "block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits); + + // Extract indices and broadcast them + array lhs_indices = indices_or_default(lhs_indices_, x, s); + array rhs_indices = indices_or_default(rhs_indices_, w, s); + auto out_bsx_shape = + broadcast_shapes(lhs_indices.shape(), rhs_indices.shape()); + lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s); + rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s); + + // Compute the full output shape + auto out_shape = out_bsx_shape; + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer_dims); + + // and output type + auto out_type = result_type(x, scales, biases); + + auto out = array( + std::move(out_shape), + out_type, + std::make_shared( + to_stream(s), group_size, bits, transpose), + {astype(x, out_type, s), + w, + astype(scales, out_type, s), + astype(biases, out_type, s), + lhs_indices, + rhs_indices}); + + return out; +} + array tensordot( const array& a, const array& b, @@ -3879,24 +3990,8 @@ array block_sparse_mm( b = astype(b, out_type, s); // Handle broadcasting - std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); - std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); - - auto indices_or_default = [&](const std::optional& indices, - const std::vector& bsx_shape) { - if (indices.has_value()) { - return indices.value(); - } else { - int n_batch = 1; - for (auto& i : bsx_shape) - n_batch *= i; - return reshape(arange(n_batch, uint32, s), bsx_shape, s); - } - }; - - // Pull and broadcast indices - array lhs_indices = indices_or_default(lhs_indices_, bsx_a); - array rhs_indices = indices_or_default(rhs_indices_, bsx_b); + array lhs_indices = indices_or_default(lhs_indices_, a, s); + array rhs_indices = indices_or_default(rhs_indices_, b, s); if (!issubdtype(lhs_indices.dtype(), integer)) { throw std::invalid_argument( diff --git a/mlx/ops.h b/mlx/ops.h index c43437f13..c334024ac 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1157,6 
+1157,19 @@ array dequantize( int bits = 4, StreamOrDevice s = {}); +/** Compute matrix products with matrix-level gather. */ +array block_sparse_qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + bool transpose = true, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + /** Returns a contraction of a and b over multiple dimensions. */ array tensordot( const array& a, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b5f144384..89454c035 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2372,7 +2372,85 @@ std::vector QuantizedMatmul::jvp( bool QuantizedMatmul::is_equivalent(const Primitive& other) const { const QuantizedMatmul& qm_other = static_cast(other); - return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_; + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + transpose_ == qm_other.transpose_; +} + +std::pair, std::vector> BlockSparseQMM::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("BlockSparseQMM::vmap NYI"); +} + +std::vector BlockSparseQMM::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + + auto& cotan = cotangents[0]; + + auto& x = primals[0]; + auto& w = primals[1]; + auto& scales = primals[2]; + auto& biases = primals[3]; + auto& lhs_indices = primals[4]; + auto& rhs_indices = primals[5]; + + for (auto arg : argnums) { + // gradient wrt to x + if (arg == 0) { + vjps.push_back(reshape( + scatter_add( + flatten(zeros_like(x, stream()), 0, -3, stream()), + lhs_indices, + expand_dims( + block_sparse_qmm( + cotan, + w, + scales, + biases, + std::nullopt, + rhs_indices, + !transpose_, + group_size_, + bits_, + stream()), + -3, + stream()), + 0, + stream()), + x.shape(), + stream())); + } + + // gradient wrt to the indices is undefined + else if (arg > 3) { + throw std::runtime_error( + "BlockSparseQMM::vjp cannot compute the gradient wrt the indices."); + } + + // gradient wrt to w_q, scales or biases + else { + throw std::runtime_error( + "BlockSparseQMM::vjp no gradient wrt the quantized matrix yet."); + } + } + return vjps; +} + +std::vector BlockSparseQMM::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + throw std::runtime_error("BlockSparseQMM::jvp NYI"); +} + +bool BlockSparseQMM::is_equivalent(const Primitive& other) const { + const BlockSparseQMM& qm_other = static_cast(other); + return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && + transpose_ == qm_other.transpose_; } std::pair, std::vector> RandomBits::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index 868b5e7f5..dff21a072 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1467,6 +1467,34 @@ class QuantizedMatmul : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class BlockSparseQMM : public UnaryPrimitive { + public: + explicit BlockSparseQMM( + Stream stream, + int group_size, + int bits, + bool transpose) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + transpose_(transpose) {}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(BlockSparseQMM) + bool is_equivalent(const Primitive& other) const override; + + private: 
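The vjp above routes the cotangent only to x (a block_sparse_qmm with the transpose flipped, followed by a scatter_add over lhs_indices back into an x-shaped buffer); asking for gradients of the packed weights, scales, biases or the index arrays raises. A hedged Python sketch of how this surfaces at the op level, with illustrative shapes:

import mlx.core as mx

x = mx.random.normal(shape=(3, 8, 512))             # 3 input batches, M=8, K=512
w = mx.random.normal(shape=(2, 128, 512))           # 2 candidate matrices, N=128
w_q, scales, biases = mx.quantize(w)
rhs_indices = mx.array([1, 0, 1], dtype=mx.uint32)  # one matrix per input batch

def loss(x):
    # lhs_indices defaults to arange over x's batch dimensions when omitted.
    out = mx.block_sparse_qmm(x, w_q, scales, biases, rhs_indices=rhs_indices)
    return out.sum()                                 # out has shape (3, 8, 128)

dx = mx.grad(loss)(x)
print(dx.shape)  # (3, 8, 512); gradients w.r.t. w_q, scales, biases or indices raise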
+ int group_size_; + int bits_; + bool transpose_; + + void eval(const std::vector& inputs, array& out); +}; + class RandomBits : public UnaryPrimitive { public: explicit RandomBits(Stream stream, const std::vector& shape, int width) diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index a8327a280..1e15a59cc 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -4,6 +4,7 @@ import math import mlx.core as mx from mlx.nn.layers.base import Module +from mlx.nn.layers.quantized import QuantizedEmbedding class Embedding(Module): @@ -37,3 +38,7 @@ class Embedding(Module): weights are tied. """ return x @ self.weight.T + + def to_quantized(self, group_size: int = 64, bits: int = 4): + """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" + return QuantizedEmbedding.from_embedding(self, group_size, bits) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 38eea7791..63caa911c 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -5,6 +5,7 @@ from typing import Any import mlx.core as mx from mlx.nn.layers.base import Module +from mlx.nn.layers.quantized import QuantizedLinear class Identity(Module): @@ -69,6 +70,10 @@ class Linear(Module): x = x @ self["weight"].T return x + def to_quantized(self, group_size: int = 64, bits: int = 4): + """Return a :obj:`QuantizedLinear` layer that approximates this layer.""" + return QuantizedLinear.from_linear(self, group_size, bits) + class Bilinear(Module): r"""Applies a bilinear transformation to the inputs. diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 192a28fd0..b8d727d88 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -5,8 +5,6 @@ from typing import Callable, Optional import mlx.core as mx from mlx.nn.layers.base import Module -from mlx.nn.layers.embedding import Embedding -from mlx.nn.layers.linear import Linear from mlx.utils import tree_map_with_path @@ -18,8 +16,9 @@ def quantize( ): """Quantize the sub-modules of a module according to a predicate. - By default all :obj:`Linear` and :obj:`Embedding` layers will be - quantized. Note also, the module is updated in-place. + By default all layers that define a ``to_quantized(group_size, bits)`` + method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers + will be quantized. Note also, the module is updated in-place. Args: model (mlx.nn.Module): The model whose leaf modules may be quantized. @@ -30,18 +29,15 @@ def quantize( class_predicate (Optional[Callable]): A callable which receives the :obj:`Module` path and :obj:`Module` itself and returns ``True`` if it should be quantized and ``False`` otherwise. If ``None``, then - all linear and embedding layers are quantized. Default: ``None``. + all layers that define a ``to_quantized(group_size, bits)`` method + are quantized. Default: ``None``. 
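Since the default predicate is now simply "does the module define to_quantized", individual layers can be converted directly and custom modules can opt into nn.quantize by implementing the method themselves. A short sketch; ScaledLinear is a hypothetical layer used only for illustration:

import mlx.core as mx
import mlx.nn as nn

# Individual layers convert directly ...
qlin = nn.Linear(512, 256, bias=False).to_quantized(group_size=64, bits=4)
qemb = nn.Embedding(1000, 512).to_quantized(group_size=64, bits=4)

# ... and a custom leaf module opts into nn.quantize via to_quantized.
class ScaledLinear(nn.Module):
    def __init__(self, in_dims, out_dims, scale=2.0):
        super().__init__()
        self.weight = mx.random.normal(shape=(out_dims, in_dims))
        self.scale = scale

    def __call__(self, x):
        return self.scale * (x @ self.weight.T)

    def to_quantized(self, group_size: int = 64, bits: int = 4):
        # Fold the scale into the weight before packing it.
        ql = nn.QuantizedLinear(
            self.weight.shape[1], self.weight.shape[0], bias=False,
            group_size=group_size, bits=bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            self.scale * self.weight, group_size, bits)
        return ql

model = nn.Sequential(ScaledLinear(512, 512), nn.Linear(512, 256))
nn.quantize(model)  # both layers are replaced in place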
""" - class_predicate = class_predicate or ( - lambda _, m: isinstance(m, (Linear, Embedding)) - ) + class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) def _maybe_quantize(path, m): if class_predicate(path, m): - if isinstance(m, Linear): - return QuantizedLinear.from_linear(m, group_size, bits) - elif isinstance(m, Embedding): - return QuantizedEmbedding.from_embedding(m, group_size, bits) + if hasattr(m, "to_quantized"): + return m.to_quantized(group_size, bits) else: raise ValueError(f"Unable to quantize model of type {type(m)}") else: @@ -129,7 +125,7 @@ class QuantizedEmbedding(Module): @classmethod def from_embedding( - cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4 + cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape @@ -220,7 +216,7 @@ class QuantizedLinear(Module): return x @classmethod - def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4): + def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 82c141e31..23d503d4b 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3747,6 +3747,52 @@ void init_ops(nb::module_& m) { Returns: result (array): The dequantized version of ``w`` )pbdoc"); + m.def( + "block_sparse_qmm", + &block_sparse_qmm, + nb::arg(), + nb::arg(), + "scales"_a, + "biases"_a, + "lhs_indices"_a = nb::none(), + "rhs_indices"_a = nb::none(), + "transpose"_a = true, + "group_size"_a = 64, + "bits"_a = 4, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Perform quantized matrix multiplication with matrix-level gather. + + This operation is the quantized equivalent to :func:`block_sparse_mm`. + Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and + ``rhs_indices`` contain flat indices along the batch dimensions (i.e. + all but the last two dimensions) of ``x`` and ``w`` respectively. + + Note that ``scales`` and ``biases`` must have the same batch dimensions + as ``w`` since they represent the same quantized matrix. + + Args: + x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``) + rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``) + transpose (bool, optional): Defines whether to multiply with the + transposed ``w`` or not, namely whether we are performing + ``x @ w.T`` or ``x @ w``. (default: ``True``) + group_size (int, optional): The size of the group in ``w`` that + shares a scale and bias. (default: ``64``) + bits (int, optional): The number of bits occupied by each element in + ``w``. 
(default: ``4``) + + Returns: + result (array): The result of the multiplication of ``x`` with ``w`` + after gathering using ``lhs_indices`` and ``rhs_indices``. + )pbdoc"); m.def( "tensordot", [](const array& a, @@ -3933,7 +3979,7 @@ void init_ops(nb::module_& m) { Matrix multiplication with matrix-level gather. Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. - This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``. + This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`. The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively. diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 2c214abbd..21dcc3103 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -277,6 +277,148 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_block_sparse_qmm(self): + def quantize(w, transpose=True, group_size=64, bits=4): + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + if transpose: + w_hat = w_hat.swapaxes(-1, -2) + return w_hat, qw, s, b + + def test_shape( + M, + N, + K, + dtype=mx.float32, + batch_A=(), + batch_B=(), + lhs_indices=None, + rhs_indices=None, + transpose=True, + group_size=64, + bits=4, + ): + with self.subTest( + M=M, + N=N, + K=K, + dtype=dtype, + batch_A=batch_A, + batch_B=batch_B, + lhs_indices=lhs_indices, + rhs_indices=rhs_indices, + transpose=transpose, + group_size=group_size, + bits=bits, + ): + x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype) + w = mx.random.normal( + shape=batch_B + ((N, K) if transpose else (K, N)) + ).astype(dtype) + w_hat, qw, s, b = quantize(w, transpose, group_size, bits) + + if lhs_indices is not None: + lhs_indices = mx.array(lhs_indices) + if rhs_indices is not None: + rhs_indices = mx.array(rhs_indices) + + c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices) + c2 = mx.block_sparse_qmm( + x, + qw, + s, + b, + lhs_indices, + rhs_indices, + transpose=transpose, + group_size=group_size, + bits=bits, + ) + + self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) + + inputs = ( + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (1,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (2,), + "lhs_indices": None, + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (3,), + "lhs_indices": (0, 2), + "batch_B": (1,), + "rhs_indices": (0,), + }, + { + "batch_A": (5,), + "lhs_indices": (0, 2), + "batch_B": (3,), + "rhs_indices": (2, 1), + }, + { + "batch_A": (4, 2), + "lhs_indices": ( + (7, 6), + (5, 4), + (1, 2), + ), + "batch_B": (4, 1), + "rhs_indices": ((2,), (0,), (1,)), + }, + ) + + for kwargs in inputs: + test_shape(32, 32, 256, **kwargs) + test_shape(1, 32, 256, **kwargs) + test_shape(32, 256, 32, transpose=False, **kwargs) + test_shape(1, 256, 32, transpose=False, **kwargs) + test_shape(32, 32, 512, **kwargs) + test_shape(1, 32, 512, **kwargs) + test_shape(32, 512, 32, transpose=False, **kwargs) + test_shape(1, 512, 32, transpose=False, **kwargs) + + def test_block_sparse_matmul_grad(self): + def quantize(w, 
transpose=True, group_size=64, bits=4): + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + if transpose: + w_hat = w_hat.swapaxes(-1, -2) + return w_hat, qw, s, b + + lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32) + rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32) + + x = mx.random.normal((4, 2, 32, 256)) + w = mx.random.normal((4, 1, 32, 256)) + w_hat, qw, s, b = quantize(w) + + def f_ref(x, w, i1, i2): + return mx.block_sparse_mm(x, w, i1, i2).sum() + + def f_test(x, qw, s, b, i1, i2): + return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum() + + r1 = f_ref(x, w_hat, lhs_indices, rhs_indices) + r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices) + self.assertTrue(mx.allclose(r1, r2, atol=1e-4)) + + g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices) + g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices) + self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) + if __name__ == "__main__": unittest.main()
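For a sum-of-outputs loss like the one in the test above, the gradient with respect to x has a simple closed form: the all-ones cotangent multiplied by the gathered, dequantized matrices. That makes a quick sanity check easy to write; a sketch along those lines, with illustrative sizes:

import mlx.core as mx

x = mx.random.normal(shape=(2, 4, 256))
w = mx.random.normal(shape=(3, 32, 256))
w_q, scales, biases = mx.quantize(w)
rhs_indices = mx.array([2, 0], dtype=mx.uint32)

def loss(x):
    return mx.block_sparse_qmm(x, w_q, scales, biases, rhs_indices=rhs_indices).sum()

g = mx.grad(loss)(x)

# d/dx sum(x @ W.T) = 1 @ W, with W gathered per batch by rhs_indices.
w_hat = mx.dequantize(w_q, scales, biases)
expected = mx.ones((2, 4, 32)) @ mx.take(w_hat, rhs_indices, axis=0)
print(mx.allclose(g, expected, atol=1e-4))  # True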