From 5de6d94a903c9bf315e5548b03ec5294e5169c1a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 17 Apr 2025 13:53:11 -0700 Subject: [PATCH] Gather qmm batched kernel and refactoring of quantized (#2078) --- benchmarks/python/gather_mm_bench.py | 2 +- benchmarks/python/gather_qmm_bench.py | 84 ++ mlx/backend/metal/jit_kernels.cpp | 39 + mlx/backend/metal/kernels.h | 15 + mlx/backend/metal/kernels/quantized.h | 451 ++++++-- mlx/backend/metal/kernels/quantized.metal | 33 +- mlx/backend/metal/matmul.cpp | 3 +- mlx/backend/metal/nojit_kernels.cpp | 17 + mlx/backend/metal/quantized.cpp | 1132 +++++++++++++++------ mlx/ops.cpp | 15 +- mlx/ops.h | 1 + mlx/primitives.cpp | 3 + mlx/primitives.h | 17 +- python/src/ops.cpp | 51 +- python/tests/test_quantized.py | 65 +- 15 files changed, 1479 insertions(+), 449 deletions(-) create mode 100644 benchmarks/python/gather_qmm_bench.py diff --git a/benchmarks/python/gather_mm_bench.py b/benchmarks/python/gather_mm_bench.py index 85ddb08a6..ffeb73487 100644 --- a/benchmarks/python/gather_mm_bench.py +++ b/benchmarks/python/gather_mm_bench.py @@ -1,4 +1,4 @@ -# Copyright © 2023-2024 Apple Inc. +# Copyright © 2025 Apple Inc. import mlx.core as mx from time_utils import time_fn diff --git a/benchmarks/python/gather_qmm_bench.py b/benchmarks/python/gather_qmm_bench.py new file mode 100644 index 000000000..17c06d57d --- /dev/null +++ b/benchmarks/python/gather_qmm_bench.py @@ -0,0 +1,84 @@ +# Copyright © 2025 Apple Inc. + +import mlx.core as mx +from time_utils import time_fn + +N = 1024 +D = 1024 +M = 1024 +E = 32 +I = 4 + + +def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + +def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + +def gather_mm_simulate(x, w, indices): + x, idx, inv_order = gather_sort(x, indices) + for i in range(2): + y = mx.concatenate( + [ + mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True) + for i, j in enumerate(idx.tolist()) + ], + axis=0, + ) + x = y[:, None] + x = scatter_unsort(x, inv_order, indices.shape) + return x + + +def time_gather_qmm(): + x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 + w1 = mx.random.normal((E, M, D)) / 1024**0.5 + w2 = mx.random.normal((E, D, M)) / 1024**0.5 + w1 = mx.quantize(w1) + w2 = mx.quantize(w2) + indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) + sorted_indices = mx.sort(indices.flatten()).reshape(N, I) + mx.eval(x, w1, w2, indices, sorted_indices) + + def gather_mm(x, w1, w2, indices, sort): + idx = indices + inv_order = None + if sort: + x, idx, inv_order = gather_sort(x, indices) + x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort) + x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort) + if sort: + x = scatter_unsort(x, inv_order, indices.shape) + return x + + time_fn(gather_mm, x, w1, w2, indices, False) + time_fn(gather_mm, x, w1, w2, sorted_indices, False) + time_fn(gather_mm, x, w1, w2, indices, True) + + x = mx.random.normal((N * I, D)) / 1024**0.5 + w1 = mx.random.normal((M, D)) / 1024**0.5 + w2 = mx.random.normal((D, M)) / 1024**0.5 + w1 = mx.quantize(w1) + w2 = mx.quantize(w2) + mx.eval(x, w1, w2) + + def equivalent_matmul(x, w1, w2): + x = mx.quantized_matmul(x, *w1, transpose=True) + x = mx.quantized_matmul(x, *w2, transpose=True) + return x + + 
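+    # Reference timing: the same two quantized matmuls with a single weight
+    # matrix per layer and no gather, so the gather_qmm timings above can be
+    # compared against a plain dense baseline doing the same amount of work.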
time_fn(equivalent_matmul, x, w1, w2) + + +if __name__ == "__main__": + time_gather_qmm() diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index c0a698a86..5206c9b54 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -752,4 +752,43 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose) { + const auto& lib_name = kernel_name; + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source; + concatenate( + kernel_source, + metal::utils(), + metal::gemm(), + metal::quantized(), + get_template_definition( + lib_name, + "gather_qmm_rhs", + get_type_string(x.dtype()), + group_size, + bits, + bm, + bn, + bk, + wm, + wn, + transpose)); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib, hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index ba5914140..6d8864385 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -224,6 +224,21 @@ MTL::ComputePipelineState* get_quantized_kernel( const std::string& kernel_name, const std::string& template_def); +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array& x, + int group_size, + int bits, + int bm, + int bn, + int bk, + int wm, + int wn, + bool transpose); + // Create a GPU kernel template definition for JIT compilation template std::string diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index af9d7860e..b2b0d8d8f 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -3,6 +3,10 @@ #include #include +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + using namespace metal; #define MLX_MTL_CONST static constant constexpr const @@ -1686,26 +1690,26 @@ template < } template -[[kernel]] void bs_qmv_fast( +[[kernel]] void gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size 
[[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1748,26 +1752,26 @@ template } template -[[kernel]] void bs_qmv( +[[kernel]] void gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1810,26 +1814,26 @@ template } template -[[kernel]] void bs_qvm( +[[kernel]] void gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& batch_ndims [[buffer(15)]], - const constant int* batch_shape [[buffer(16)]], - const device uint32_t* lhs_indices [[buffer(17)]], - const device uint32_t* rhs_indices [[buffer(18)]], + const device uint32_t* lhs_indices 
[[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], const constant int64_t* lhs_strides [[buffer(19)]], const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1879,27 +1883,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_t( +[[kernel]] void gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -1946,27 +1950,27 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void bs_qmm_n( +[[kernel]] void gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const 
constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - const constant int& batch_ndims [[buffer(16)]], - const constant int* batch_shape [[buffer(17)]], - const device uint32_t* lhs_indices [[buffer(18)]], - const device uint32_t* rhs_indices [[buffer(19)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], const constant int64_t* lhs_strides [[buffer(20)]], const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], @@ -2007,6 +2011,289 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void gather_qmm_rhs( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? 
y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + template [[kernel]] void affine_quantize( const device T* w [[buffer(0)]], diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 7af554437..11cd8421b 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -60,6 +60,20 @@ bits, \ split_k) +#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ + 
instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + group_size, \ + bits, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 0) @@ -73,14 +87,14 @@ #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ - instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ - instantiate_quantized(bs_qmv, type, group_size, bits) \ - instantiate_quantized(bs_qvm, type, group_size, bits) \ - instantiate_quantized(bs_qmm_n, type, group_size, bits) + instantiate_quantized(gather_qmv_fast, type, group_size, bits) \ + instantiate_quantized(gather_qmv, type, group_size, bits) \ + instantiate_quantized(gather_qvm, type, group_size, bits) \ + instantiate_quantized(gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \ - instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ @@ -96,12 +110,17 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) +#define instantiate_quantized_all_rhs(type, group_size, bits) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_all_splitk(type, group_size, bits) + instantiate_quantized_all_splitk(type, group_size, bits) \ + instantiate_quantized_all_rhs(type, group_size, bits) #define instantiate_quantized_types(group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \ diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 27369ad07..f55d20c9f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -1908,8 +1908,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { out.set_data(allocator::malloc(out.nbytes())); - // Extract shapes strides from inputs and copy in case of non-contiguous - // vectors. + // Extract shapes from inputs. 
int M = a.shape(-2); int N = b.shape(-1); int K = a.shape(-1); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 292af6919..8da147971 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -269,4 +269,21 @@ MTL::ComputePipelineState* get_quantized_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_gather_qmm_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const array&, + int, + int, + int, + int, + int, + int, + int, + bool) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 8d1d176c4..5b3ec027b 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -2,6 +2,7 @@ #include +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -14,93 +15,168 @@ namespace mlx::core { -void launch_qmm( - std::string name, - const std::vector& inputs, +namespace { + +inline array +ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } else { + return x; + } +} + +inline array ensure_row_contiguous_matrix( + const array& x, + metal::Device& d, + const Stream& s) { + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } else { + array x_copy(x.shape(), x.dtype(), nullptr, {}); + copy_gpu(x, x_copy, CopyType::General, s); + d.add_temporary(x_copy, s.index); + return x_copy; + } +} + +inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { + auto arch = d.get_architecture(); + auto arch_size = arch.back(); + auto arch_gen = arch.substr(arch.size() - 3, 2); + if (arch_gen == "13" || arch_gen == "14") { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 14; + } else if (D <= 4096 && O <= 4096) { + return 10; + } else { + return 6; + } + } + } else { + switch (arch_size) { + case 'd': + if (D <= 2048 && O <= 2048) { + return 32; + } else if (D <= 4096 && O <= 4096) { + return 18; + } else { + return 12; + } + default: + if (D <= 2048 && O <= 2048) { + return 18; + } else if (D <= 4096 && O <= 4096) { + return 12; + } else { + return 10; + } + } + } +} + +inline int add_strides_and_shapes( + CommandEncoder& compute_encoder, + bool skip, + const array& x, + const array& w, + const array& scales, + const array& biases, + int offset) { + if (skip) { + return 0; + } + + // TODO: Collapse batch dimensions + + int x_batch_ndims = x.ndim() - 2; + int w_batch_ndims = w.ndim() - 2; + compute_encoder.set_bytes(x_batch_ndims, offset); + compute_encoder.set_vector_bytes(x.shape(), offset + 1); + compute_encoder.set_vector_bytes(x.strides(), offset + 2); + compute_encoder.set_bytes(w_batch_ndims, offset + 3); + compute_encoder.set_vector_bytes(w.shape(), offset + 4); + compute_encoder.set_vector_bytes(w.strides(), offset + 5); + compute_encoder.set_vector_bytes(scales.strides(), offset + 6); + 
compute_encoder.set_vector_bytes(biases.strides(), offset + 7); + + return 8; +} + +inline int add_gather_strides_and_shapes( + CommandEncoder& compute_encoder, + const array& lhs_indices, + const array& rhs_indices, + int offset) { + auto [shape, strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + int ndims = shape.size(); + + compute_encoder.set_bytes(ndims, offset); + compute_encoder.set_vector_bytes(shape, offset + 1); + compute_encoder.set_vector_bytes(strides[0], offset + 2); + compute_encoder.set_vector_bytes(strides[1], offset + 3); + + return 4; +} + +} // namespace + +void qmv_quad( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, - MTL::Size& group_dims, - MTL::Size& grid_dims, - bool batched, - bool matrix, - bool gather, - bool aligned, - bool quad, + int K, + metal::Device& d, const Stream& s) { - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + int B = out.size() / M / N; - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... - std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + constexpr int quads_per_simd = 8; + constexpr int results_per_quadgroup = 8; + int bn = quads_per_simd * results_per_quadgroup; + int simdgroup_size = 32; + MTL::Size group_dims(simdgroup_size, 1, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - int x_batch_ndims = x.ndim() - 2; - auto& x_shape = x.shape(); - auto& x_strides = x.strides(); - int w_batch_ndims = w.ndim() - 2; - auto& w_shape = w.shape(); - auto& w_strides = w.strides(); - auto& s_strides = scales.strides(); - auto& b_strides = biases.strides(); + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qmv_quad_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_d_", + K, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qmv_quad", type_string, group_size, bits, K, B > 1); - std::string aligned_n = (O % 32) == 0 ? 
"true" : "false"; - - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; - if (quad) { - kname << "_d_" << D; - } - if (aligned) { - kname << "_alN_" << aligned_n; - } - if (!gather) { - kname << "_batch_" << batched; - } - - // Encode and dispatch kernel - std::string template_def; - if (quad) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, D, batched); - } else if (aligned && !gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n, batched); - } else if (!gather && !aligned) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, batched); - } else if (aligned && gather) { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n); - } else { - template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits); - } - auto& d = metal::device(s.device); - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -109,90 +185,87 @@ void launch_qmm( compute_encoder.set_input_array(biases, 2); compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(D, 5); - compute_encoder.set_bytes(O, 6); - - int offset = 7; - if (matrix) { - compute_encoder.set_bytes(B, 7); - offset += 1; - } - - if (batched || gather) { - compute_encoder.set_bytes(x_batch_ndims, offset); - compute_encoder.set_vector_bytes(x_shape, offset + 1); - compute_encoder.set_vector_bytes(x_strides, offset + 2); - compute_encoder.set_bytes(w_batch_ndims, offset + 3); - compute_encoder.set_vector_bytes(w_shape, offset + 4); - compute_encoder.set_vector_bytes(w_strides, offset + 5); - compute_encoder.set_vector_bytes(s_strides, offset + 6); - compute_encoder.set_vector_bytes(b_strides, offset + 7); - } - if (gather) { - auto& lhs_indices = inputs[4]; - auto& rhs_indices = inputs[5]; - - // TODO: collapse batch dims - auto& batch_shape = lhs_indices.shape(); - int batch_ndims = batch_shape.size(); - auto& lhs_strides = lhs_indices.strides(); - auto& rhs_strides = rhs_indices.strides(); - - compute_encoder.set_bytes(batch_ndims, offset + 8); - compute_encoder.set_vector_bytes(batch_shape, offset + 9); - compute_encoder.set_input_array(lhs_indices, offset + 10); - compute_encoder.set_input_array(rhs_indices, offset + 11); - compute_encoder.set_vector_bytes(lhs_strides, offset + 12); - compute_encoder.set_vector_bytes(rhs_strides, offset + 13); - } + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); } -void qvm_split_k( - const std::vector& inputs, +void qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, int group_size, int bits, - int D, - int O, - int B, + int M, int N, + int K, + metal::Device& d, const Stream& s) { - int split_k = D > 8192 ? 
32 : 8; - int split_D = (D + split_k - 1) / split_k; - N *= split_k; + int B = out.size() / M / N; - int bo = 64; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, 2, 1); - MTL::Size grid_dims = MTL::Size(B, O / bo, N); + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); - auto& x_pre = inputs[0]; - auto& w_pre = inputs[1]; - auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "qmv_fast_" : "qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? "_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1); - // Ensure that the last two dims are row contiguous. - // TODO: Check if we really need this for x as well... - std::vector copies; - auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) { - auto stride_0 = arr.strides()[arr.ndim() - 2]; - auto stride_1 = arr.strides()[arr.ndim() - 1]; - if (stride_0 == arr.shape(-1) && stride_1 == 1) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto x = ensure_row_contiguous_last_dims(x_pre); - auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qvm_split_k( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int split_k = K > 8192 ? 
32 : 8; + int split_D = (K + split_k - 1) / split_k; + int B = out.size() / M / N; + B *= split_k; + + int bn = 64; + int bk = 32; + MTL::Size group_dims = MTL::Size(bk, 2, 1); + MTL::Size grid_dims = MTL::Size(M, N / bn, B); int x_batch_ndims = x.ndim() - 2; auto x_shape = x.shape(); @@ -217,9 +290,7 @@ void qvm_split_k( s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1)); - int final_block_size = D - (split_k - 1) * split_D; - - auto& d = metal::device(s.device); + int final_block_size = K - (split_k - 1) * split_D; auto temp_shape = out.shape(); temp_shape.insert(temp_shape.end() - 2, split_k); @@ -227,15 +298,24 @@ void qvm_split_k( intermediate.set_data(allocator::malloc(intermediate.nbytes())); d.add_temporary(intermediate, s.index); - std::ostringstream kname; - auto type_string = get_type_string(x.dtype()); - kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_" - << bits << "_spk_" << split_k; + std::string type_string = get_type_string(x.dtype()); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "qvm_split_k_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_spk_", + split_k); auto template_def = get_template_definition( - kname.str(), "qvm_split_k", type_string, group_size, bits, split_k); + kname, "qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel - auto kernel = get_quantized_kernel(d, kname.str(), template_def); + auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -245,7 +325,7 @@ void qvm_split_k( compute_encoder.set_input_array(x, 3); compute_encoder.set_output_array(intermediate, 4); compute_encoder.set_bytes(split_D, 5); - compute_encoder.set_bytes(O, 6); + compute_encoder.set_bytes(N, 6); compute_encoder.set_bytes(x_batch_ndims, 7); compute_encoder.set_vector_bytes(x_shape, 8); @@ -258,7 +338,6 @@ void qvm_split_k( compute_encoder.set_bytes(final_block_size, 15); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); - d.add_temporaries(std::move(copies), s.index); int axis = intermediate.ndim() - 3; ReductionPlan plan( @@ -269,170 +348,589 @@ void qvm_split_k( intermediate, out, "sum", plan, {axis}, compute_encoder, d, s); } -void qmm_op( - const std::vector& inputs, +void qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + "qvm_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + B > 1 ? 
"_batch_1" : "_batch_0"); + auto template_def = get_template_definition( + kname, "qvm", type_string, group_size, bits, B > 1); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, array& out, bool transpose, int group_size, int bits, - bool gather, + int M, + int N, + int K, + metal::Device& d, const Stream& s) { - out.set_data(allocator::malloc(out.nbytes())); + int B = out.size() / M / N; - MTL::Size group_dims; - MTL::Size grid_dims; + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); - auto& x = inputs[0]; - auto& w = inputs[1]; - bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous); + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "qmm_t_" : "qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? "_alN_true" : "_alN_false") : "", + batched ? "_batch_1" : "_batch_0"); + std::string template_def; + if (transpose) { + template_def = get_template_definition( + kname, "qmm_t", type_string, group_size, bits, aligned, batched); + } else { + template_def = get_template_definition( + kname, "qmm_n", type_string, group_size, bits, batched); + } - int D = x.shape(-1); - int O = out.shape(-1); - // For the unbatched W case, avoid `adjust_matrix_offsets` - // for a small performance gain. - int B = (batched || gather) ? x.shape(-2) : x.size() / D; - int N = (batched || gather) ? out.size() / B / O : 1; + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); - std::string name = gather ? 
"bs_" : ""; - bool matrix = false; - bool aligned = false; - bool quad = false; + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_output_array(out, 4); + compute_encoder.set_bytes(K, 5); + compute_encoder.set_bytes(N, 6); + compute_encoder.set_bytes(M, 7); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); - auto get_qmv_batch_limit = [s](int D, int O) { - auto arch = metal::device(s.device).get_architecture(); - auto arch_size = arch.back(); - auto arch_gen = arch.substr(arch.size() - 3, 2); - if (arch_gen == "13" || arch_gen == "14") { - switch (arch_size) { - case 'd': - if (D <= 2048 && O <= 2048) { - return 32; - } else if (D <= 4096 && O <= 4096) { - return 18; - } else { - return 12; - } - default: - if (D <= 2048 && O <= 2048) { - return 14; - } else if (D <= 4096 && O <= 4096) { - return 10; - } else { - return 6; - } - } - } else { - switch (arch_size) { - case 'd': - if (D <= 2048 && O <= 2048) { - return 32; - } else if (D <= 4096 && O <= 4096) { - return 18; - } else { - return 12; - } - default: - if (D <= 2048 && O <= 2048) { - return 18; - } else if (D <= 4096 && O <= 4096) { - return 12; - } else { - return 10; - } - } + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 32; + int bn = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 32 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_t_" : "gather_qmm_n_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + transpose ? (aligned ? 
"_alN_true" : "_alN_false") : ""); + std::string template_def; + if (transpose) { + template_def = get_template_definition( + kname, "gather_qmm_t", type_string, group_size, bits, aligned); + } else { + template_def = get_template_definition( + kname, "gather_qmm_n", type_string, group_size, bits); + } + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + compute_encoder.set_bytes(M, 9); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 10 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmv( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + bool fast = N % bn == 0 && K % 512 == 0; + concatenate( + kname, + fast ? "gather_qmv_fast_" : "gather_qmv_", + type_string, + "_gs_", + group_size, + "_b_", + bits); + auto template_def = get_template_definition( + kname, + fast ? 
"gather_qmv_fast" : "gather_qmv", + type_string, + group_size, + bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qvm( + const array& x, + const array& w, + const array& scales, + const array& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + int B = out.size() / M / N; + + int bn = 64; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(M, (N + bn - 1) / bn, B); + + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); + auto template_def = get_template_definition( + kname, "gather_qvm", type_string, group_size, bits); + + auto kernel = get_quantized_kernel(d, kname, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(biases, 2); + compute_encoder.set_input_array(x, 3); + compute_encoder.set_input_array(lhs_indices, 4); + compute_encoder.set_input_array(rhs_indices, 5); + compute_encoder.set_output_array(out, 6); + compute_encoder.set_bytes(K, 7); + compute_encoder.set_bytes(N, 8); + int n = + add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); + add_gather_strides_and_shapes( + compute_encoder, lhs_indices, rhs_indices, 9 + n); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm_rhs( + const array& x_, + const array& w_, + const array& scales_, + const array& biases_, + const array& indices_, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s) { + // Start by normalizing the indices + array indices = ensure_row_contiguous(indices_, d, s); + + // Broadcast x with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of x broadcasted + // with rhs_indices. We need only broadcast x and copy it as if applying the + // lhs_indices. 
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); }; - if (transpose) { - auto qmv_batch_limit = get_qmv_batch_limit(D, O); - if (B < qmv_batch_limit && (D == 128 || D == 64) && is_power_of_2(bits)) { - name += "qmv_quad"; - constexpr int quads_per_simd = 8; - constexpr int results_per_quadgroup = 8; - int bo = quads_per_simd * results_per_quadgroup; - int simdgroup_size = 32; - group_dims = MTL::Size(simdgroup_size, 1, 1); - grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); - quad = true; - } else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) { - name += "qmv_fast"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else if (B < qmv_batch_limit) { - name += "qmv"; - int bo = 8; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); - } else { - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N); - name += "qmm_t"; - matrix = true; - aligned = true; - } - } else { - if (B < 4 && D >= 1024 && !gather) { - return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s); - } else if (B < 4) { - name += "qvm"; - int bo = 64; - int bd = 32; - group_dims = MTL::Size(bd, 2, 1); - grid_dims = MTL::Size(B, O / bo, N); - } else { - name += "qmm_n"; - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - group_dims = MTL::Size(32, wn, wm); - grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N); - matrix = true; - if ((O % bn) != 0) { - std::ostringstream msg; - msg << "[quantized_matmul] The output size should be divisible by " - << bn << " but received " << O << "."; - throw std::runtime_error(msg.str()); - } - } - } - launch_qmm( - name, - inputs, - out, + // Normalize the input arrays + array x = broadcast_with_indices(x_); + array w = ensure_row_contiguous(w_, d, s); + array scales = ensure_row_contiguous(scales_, d, s); + array biases = ensure_row_contiguous(biases_, d, s); + + // TODO: Tune the block sizes + int bm = 16, bn = 32, bk = 32; + int wm = 1, wn = 2; + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Make the kernel name + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_", + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm_", + bm, + "_bn_", + bn, + "_bk_", + bk, + "_wm_", + wm, + "_wn_", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + kname, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 
't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_gather_qmm_kernel( + d, + kname, + hash_name, + func_consts, + x, group_size, bits, - D, - O, - B, - N, - group_dims, - grid_dims, - batched, - matrix, - gather, - aligned, - quad, - s); + bm, + bn, + bk, + wm, + wn, + transpose); + compute_encoder.set_compute_pipeline_state(kernel); + + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(scales, 2); + compute_encoder.set_input_array(biases, 3); + compute_encoder.set_input_array(indices, 4); + compute_encoder.set_output_array(out, 5); + compute_encoder.set_bytes(M, 6); + compute_encoder.set_bytes(N, 7); + compute_encoder.set_bytes(K, 8); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 4); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + // Make sure the last two dims of x and w, s, b are contiguous. This should + // be relaxed for x. + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + + // Extract the matmul shapes + bool non_batched = w.ndim() == 2 && x.flags().row_contiguous; + int K = x.shape(-1); + int M = non_batched ? x.size() / K : x.shape(-2); + int N = out.shape(-1); + + int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; + + // It is a matrix matrix product. + if (M >= vector_limit) { + qmm(x, + w, + scales, + biases, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + // It is a qmv with a small inner dimension so route to qmv_quad kernel + if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) { + qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qmv + if (transpose_) { + qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Run of the mill qvm + if (K < 1024) { + qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; + } + + // Qvm with large dimension so route to a split K kernel for more parallelism + qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + return; } void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 6); - qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream()); + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + array x = ensure_row_contiguous_matrix(inputs[0], d, s); + array w = ensure_row_contiguous_matrix(inputs[1], d, s); + array scales = ensure_row_contiguous_matrix(inputs[2], d, s); + array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + const array& lhs_indices = inputs[4]; + const array& rhs_indices = inputs[5]; + + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + int B = out.size() / M / N; + int vector_limit = transpose_ ? 
get_qmv_batch_limit(K, N, d) : 4; + + // We are walking x in order and w is also in order so we can batch up the + // matmuls and reuse reading x and w. + // + // TODO: Tune 16 here a bit better. Maybe also choose it dynamically based + // on B and (w.size() / K / N). + if (M == 1 && B >= 16 && right_sorted_ == true) { + gather_qmm_rhs( + x, + w, + scales, + biases, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + x.size() / K, + N, + K, + d, + s); + return; + } + + // It is a matrix matrix product + if (M >= vector_limit) { + gather_qmm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + if (transpose_) { + gather_qmv( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); + return; + } + + gather_qvm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + group_size_, + bits_, + M, + N, + K, + d, + s); } void fast::AffineQuantize::eval_gpu( @@ -444,27 +942,13 @@ void fast::AffineQuantize::eval_gpu( auto& s = stream(); auto& d = metal::device(s.device); - - std::vector copies; - auto ensure_row_contiguous = [&copies, &s](const array& arr) { - if (arr.flags().row_contiguous) { - return arr; - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - return arr_copy; - } - }; - auto w = ensure_row_contiguous(w_pre); - auto& compute_encoder = d.get_command_encoder(s.index); + + auto w = ensure_row_contiguous(w_pre, d, s); compute_encoder.set_input_array(w, 0); if (dequantize_) { - auto& scales_pre = inputs[1]; - auto& biases_pre = inputs[2]; - auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); + auto scales = ensure_row_contiguous(inputs[1], d, s); + auto biases = ensure_row_contiguous(inputs[2], d, s); compute_encoder.set_input_array(scales, 1); compute_encoder.set_input_array(biases, 2); compute_encoder.set_output_array(out, 3); @@ -512,8 +996,6 @@ void fast::AffineQuantize::eval_gpu( MTL::Size grid_dims = use_2d ? get_2d_grid_dims(grid_shape, w.strides()) : MTL::Size(nthreads, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 1946a43fa..2f92088aa 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4028,6 +4028,7 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + bool sorted_indices /* = false */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( @@ -4067,13 +4068,19 @@ array gather_qmm( return array( std::move(out_shape), out_type, - std::make_shared(to_stream(s), group_size, bits, transpose), + std::make_shared( + to_stream(s), + group_size, + bits, + transpose, + sorted_indices && !rhs_indices_, + sorted_indices && !lhs_indices_), {astype(x, out_type, s), - w, + std::move(w), astype(scales, out_type, s), astype(biases, out_type, s), - lhs_indices, - rhs_indices}); + std::move(lhs_indices), + std::move(rhs_indices)}); } array tensordot( diff --git a/mlx/ops.h b/mlx/ops.h index f6fd958b3..e79ea235d 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1352,6 +1352,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + bool sorted_indices = false, StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9b34fe657..590af60f6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3080,6 +3080,8 @@ std::vector GatherQMM::vjp( auto& lhs_indices = primals[4]; auto& rhs_indices = primals[5]; + bool sorted = left_sorted_ || right_sorted_; + for (auto arg : argnums) { // gradient wrt to x if (arg == 0) { @@ -3098,6 +3100,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + sorted, stream()), -3, stream()), diff --git a/mlx/primitives.h b/mlx/primitives.h index 1902a562d..997931f30 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1591,11 +1591,19 @@ class QuantizedMatmul : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive { public: - explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + bool transpose, + bool left_sorted = false, + bool right_sorted = false) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1605,13 +1613,16 @@ class GatherQMM : public UnaryPrimitive { DEFINE_PRINT(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple( + group_size_, bits_, transpose_, left_sorted_, right_sorted_); } private: int group_size_; int bits_; bool transpose_; + bool left_sorted_; + bool right_sorted_; }; class RandomBits : public UnaryPrimitive { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8798ba482..f98aa80aa 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4250,9 +4250,10 @@ void init_ops(nb::module_& m) { "group_size"_a = 64, "bits"_a = 4, nb::kw_only(), + "sorted_indices"_a = false, "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4265,23 +4266,25 @@ void init_ops(nb::module_& m) { as ``w`` since they represent the same quantized matrix. Args: - x (array): Input array - w (array): Quantized matrix packed in unsigned integers - scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` - lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. - rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. - transpose (bool, optional): Defines whether to multiply with the - transposed ``w`` or not, namely whether we are performing - ``x @ w.T`` or ``x @ w``. Default: ``True``. - group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. Default: ``64``. - bits (int, optional): The number of bits occupied by each element in - ``w``. Default: ``4``. 
+ x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``group_size`` elements of ``w`` + biases (array): The biases to use per ``group_size`` elements of ``w`` + lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``. + rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``. + transpose (bool, optional): Defines whether to multiply with the + transposed ``w`` or not, namely whether we are performing + ``x @ w.T`` or ``x @ w``. Default: ``True``. + group_size (int, optional): The size of the group in ``w`` that + shares a scale and bias. Default: ``64``. + bits (int, optional): The number of bits occupied by each element in + ``w``. Default: ``4``. + sorted_indices (bool, optional): May allow a faster implementation + if the passed indices are sorted. Default: ``False``. Returns: - array: The result of the multiplication of ``x`` with ``w`` - after gathering using ``lhs_indices`` and ``rhs_indices``. + array: The result of the multiplication of ``x`` with ``w`` + after gathering using ``lhs_indices`` and ``rhs_indices``. )pbdoc"); m.def( "tensordot", @@ -4311,16 +4314,16 @@ void init_ops(nb::module_& m) { Compute the tensor dot product along the specified axes. Args: - a (array): Input array - b (array): Input array - axes (int or list(list(int)), optional): The number of dimensions to - sum over. If an integer is provided, then sum over the last - ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of - ``b``. If a list of lists is provided, then sum over the - corresponding dimensions of ``a`` and ``b``. Default: 2. + a (array): Input array + b (array): Input array + axes (int or list(list(int)), optional): The number of dimensions to + sum over. If an integer is provided, then sum over the last + ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of + ``b``. If a list of lists is provided, then sum over the + corresponding dimensions of ``a`` and ``b``. Default: 2. Returns: - array: The tensor dot product. + array: The tensor dot product. 
)pbdoc"); m.def( "inner", diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 160eb6400..eeefcd94f 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -174,12 +174,14 @@ class TestQuantized(mlx_tests.MLXTestCase): tests = product( [128, 64, 32], # group_size [2, 3, 4, 6, 8], # bits - [128, 256], # M + [32, 128, 256], # M [128, 256, 67], # N [0, 1, 3, 8], # B ) for group_size, bits, M, N, B in tests: with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): + if M < group_size: + continue x_shape = (1, N) if B == 0 else (B, 1, N) w_shape = (N, M) if B == 0 else (B, N, M) x = mx.random.normal(shape=x_shape, key=k1) @@ -448,6 +450,7 @@ class TestQuantized(mlx_tests.MLXTestCase): ) for kwargs in inputs: + test_shape(1, 32, 128, **kwargs) test_shape(32, 32, 256, **kwargs) test_shape(1, 32, 256, **kwargs) test_shape(32, 256, 32, transpose=False, **kwargs) @@ -486,6 +489,66 @@ class TestQuantized(mlx_tests.MLXTestCase): g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices) self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) + def test_gather_qmm_sorted(self): + def quantize(w, transpose=True, group_size=64, bits=4): + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + if transpose: + w_hat = w_hat.swapaxes(-1, -2) + return w_hat, qw, s, b + + def gather_sort(x, indices): + N, M = indices.shape + indices = indices.flatten() + order = mx.argsort(indices) + inv_order = mx.argsort(order) + return x.flatten(0, -3)[order // M], indices[order], inv_order + + def scatter_unsort(x, inv_order, shape=None): + x = x[inv_order] + if shape is not None: + x = mx.unflatten(x, 0, shape) + return x + + parameters = [ + # L, K, D, E, I, transpose + (128, 1024, 1024, 32, 4, True), + (128, 1024, 544, 32, 4, True), + (433, 1024, 1024, 32, 4, True), + (433, 1024, 555, 32, 4, True), + (433, 2048, 1024, 32, 4, True), + (128, 1024, 1024, 32, 4, False), + (128, 1024, 544, 32, 4, False), + (433, 1024, 1024, 32, 4, False), + (433, 1024, 544, 32, 4, False), + (433, 1024, 555, 32, 4, False), + (433, 2048, 1024, 32, 4, False), + ] + for L, K, D, E, I, transpose in parameters: + K, D = (K, D) if transpose else (D, K) + ishape = (L, I) + xshape = (L, 1, 1, K) + wshape = (E, D, K) if transpose else (E, K, D) + + indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) + x = mx.random.normal(xshape) / K**0.5 + w = mx.random.normal(wshape) / K**0.5 + w, *wq = quantize(w, transpose=transpose) + + y1 = mx.gather_mm(x, w, rhs_indices=indices) + y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices) + xs, idx, inv_order = gather_sort(x, indices) + y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + y4 = mx.gather_qmm( + xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True + ) + y3 = scatter_unsort(y3, inv_order, indices.shape) + y4 = scatter_unsort(y4, inv_order, indices.shape) + + self.assertTrue(mx.allclose(y1, y2, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) + self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + if __name__ == "__main__": unittest.main()