diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 737ce27b5..15137a8a7 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -4,6 +4,7 @@ import argparse
 import math
 import os
 import time
+from functools import partial
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -59,15 +60,23 @@ def matmul(x, y):
     mx.eval(ys)
 
 
-def quant_matmul(x, w, s, b):
-    groups = x.shape[-1] // s.shape[-1]
-    width = 32 // (x.shape[-1] // w.shape[0])
+def _quant_matmul(x, w, s, b, group_size, bits):
     ys = []
     for i in range(10):
-        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
+        ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
     mx.eval(ys)
 
 
+quant_matmul = {
+    "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
+    "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
+    "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
+    "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
+    "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
+    "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
+}
+
+
 def conv1d(x, y):
     ys = []
     for i in range(10):
@@ -356,8 +365,8 @@ if __name__ == "__main__":
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))
 
-    elif args.benchmark == "quant_matmul":
-        print(bench(quant_matmul, *xs))
+    elif args.benchmark.startswith("quant_matmul"):
+        print(bench(quant_matmul[args.benchmark], *xs))
 
     elif args.benchmark == "linear":
         print(bench(linear, *xs))
diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp
index 6246e1deb..cc31c88a3 100644
--- a/mlx/backend/accelerate/quantized.cpp
+++ b/mlx/backend/accelerate/quantized.cpp
@@ -76,20 +76,16 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales = inputs[2];
   auto& biases = inputs[3];
 
-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
+  bool condition =
+      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
+       scales.flags().row_contiguous && biases.flags().row_contiguous &&
+       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
 
-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
-
-  if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
+  if (condition) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
     int M = x.size() / K;
-    int N = w.shape(1);
+    int N = out.shape(-1);
     _qmm_t_4_64(
         out.data<float>(),
         x.data<float>(),
diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp
index 1a9b27953..0ac2c2b61 100644
--- a/mlx/backend/common/quantized.cpp
+++ b/mlx/backend/common/quantized.cpp
@@ -1,13 +1,62 @@
 // Copyright © 2023 Apple Inc.
#include +#include +#include "mlx/backend/metal/copy.h" #include "mlx/primitives.h" namespace mlx::core { namespace { +template +void _qmm( + T* result, + const T* x, + const uint32_t* w, + const T* scales, + const T* biases, + int M, + int N, + int K) { + constexpr int bitmask = (1 << bits) - 1; + constexpr int pack_factor = 32 / bits; + constexpr int packs_in_group = group_size / pack_factor; + const int Ng = N / group_size; + const int Nw = N / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const T* scales_local = scales; + const T* biases_local = biases; + + std::fill(result, result + N, 0); + + for (int k = 0; k < K; k++) { + T* result_local = result; + T xi = *x++; + + for (int n = 0; n < N; n += group_size) { + T scale = *scales_local++; + T bias = *biases_local++; + for (int ng = 0; ng < packs_in_group; ng++) { + uint32_t wi = *w_local++; + +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + (*result_local++) += + xi * (scale * static_cast(wi & bitmask) + bias); + wi >>= bits; + } + } + } + } + + result += N; + } +} + template void _qmm_t( T* result, @@ -55,7 +104,7 @@ void _qmm_t( } template -void _qmm_t_dispatch_typed( +void _qmm_dispatch_typed( T* result, const T* x, const uint32_t* w, @@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed( int N, int K, int group_size, - int bits) { + int bits, + bool transposed_w) { switch (bits) { case 2: { switch (group_size) { case 64: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } case 128: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } } } case 4: { switch (group_size) { case 64: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } case 128: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } } } case 8: { switch (group_size) { case 64: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } case 128: - return _qmm_t(result, x, w, scales, biases, M, N, K); + if (transposed_w) { + return _qmm_t(result, x, w, scales, biases, M, N, K); + } else { + return _qmm(result, x, w, scales, biases, M, N, K); + } } } } @@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed( throw std::invalid_argument(msg.str()); } -void _qmm_t_dispatch( +void _qmm_dispatch( array out, const array& x, const array& w, const array& scales, const array& biases, int bits, - int group_size) { + int group_size, + bool transposed_w) { int K = x.shape(-1); int M = x.size() / K; - int N = w.shape(1); + int N = out.shape(-1); switch (x.dtype()) { case float32: - _qmm_t_dispatch_typed( + _qmm_dispatch_typed( out.data(), x.data(), w.data(), @@ -124,10 +199,11 @@ void _qmm_t_dispatch( N, K, bits, - group_size); + group_size, + transposed_w); break; case float16: - _qmm_t_dispatch_typed( + _qmm_dispatch_typed( out.data(), x.data(), w.data(), @@ -137,10 +213,11 @@ void _qmm_t_dispatch( N, K, 
bits, - group_size); + group_size, + transposed_w); break; case bfloat16: - _qmm_t_dispatch_typed( + _qmm_dispatch_typed( out.data(), x.data(), w.data(), @@ -150,7 +227,8 @@ void _qmm_t_dispatch( N, K, bits, - group_size); + group_size, + transposed_w); break; default: throw std::invalid_argument( @@ -163,22 +241,28 @@ void _qmm_t_dispatch( void QuantizedMatmul::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 4); - auto& x = inputs[0]; - auto& w = inputs[1]; - auto& scales = inputs[2]; - auto& biases = inputs[3]; + auto& x_pre = inputs[0]; + auto& w_pre = inputs[1]; + auto& scales_pre = inputs[2]; + auto& biases_pre = inputs[3]; - if (w.strides()[0] != 1) { - throw std::runtime_error("The quantized weight should be transposed"); - } + auto ensure_row_contiguous = [](const array& arr) { + if (arr.flags().row_contiguous) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + return arr_copy; + } + }; - if (!x.flags().row_contiguous || !scales.flags().row_contiguous || - !biases.flags().row_contiguous) { - throw std::runtime_error("x, scales and biases should be row contiguous."); - } + auto x = ensure_row_contiguous(x_pre); + auto w = ensure_row_contiguous(w_pre); + auto scales = ensure_row_contiguous(scales_pre); + auto biases = ensure_row_contiguous(biases_pre); out.set_data(allocator::malloc_or_wait(out.nbytes())); - _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_); + _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 8a9e89450..9627dc3c0 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -104,6 +104,108 @@ template +[[kernel]] void qvm( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE"); + static_assert(BN == BM, "qvm expects a block size of 32x32"); + + (void)lid; + + constexpr int bitmask = (1 << bits) - 1; + constexpr int el_per_int = 32 / bits; + constexpr int colgroup = BN * el_per_int; + constexpr int groups_per_block = colgroup / group_size; + + threadgroup T scales_block[BM * groups_per_block]; + threadgroup T biases_block[BM * groups_per_block]; + threadgroup T x_block[BM]; + + thread uint32_t w_local; + thread T result[el_per_int] = {0}; + thread T scale = 1; + thread T bias = 0; + thread T x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size / el_per_int; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = (tid.y * BN + simd_gid) * el_per_int; + w += out_col / el_per_int; + scales += out_col / group_size; + biases += out_col / group_size; + x += tid.z * in_vec_size; + y += tid.z * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of colgroup + for (int i=0; i(w_local & bitmask) + bias) * x_local; + w_local >>= bits; + } + } + + // Accumulate in the simdgroup + #pragma 
clang loop unroll(full) + for (int k=0; k [[kernel]] void qmm_t( const device T* x [[buffer(0)]], @@ -133,8 +235,7 @@ template ; using loader_x_t = BlockLoader; @@ -231,8 +332,133 @@ template +[[kernel]] void qmm_n( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + const uint lidy = lid / SIMD_SIZE; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int bitmask = (1 << bits) - 1; + constexpr int el_per_int = 32 / bits; + constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1; + constexpr int groups_per_simd = BK / (WM * WN); + constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = BlockMMA; + using loader_x_t = BlockLoader; + + threadgroup T scales_block[BK * groups_per_block]; + threadgroup T biases_block[BK * groups_per_block]; + threadgroup T Xs[BM * BK]; + threadgroup T Ws[BK * BN]; + + // Set the block + const int N_w = N / el_per_int; + const int N_g = N / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * K; + w += y_col / el_per_int; + scales += y_col / group_size; + biases += y_col / group_size; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + for (int k=0; k(wi & bitmask) + bias; + wi >>= bits; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(Xs, Ws); + + // Prepare for next iteration + loader_x.next(); + w += BK * N_w; + scales += BK * N_g; + biases += BK * N_g; + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + + #define instantiate_qmv(name, itype, group_size, bits) \ - template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \ + template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \ [[kernel]] void qmv( \ const device uint32_t* w [[buffer(0)]], \ const device itype* scales [[buffer(1)]], \ @@ -258,6 +484,33 @@ instantiate_qmv_types( 64, 2) instantiate_qmv_types( 64, 4) instantiate_qmv_types( 64, 8) +#define instantiate_qvm(name, itype, group_size, bits) \ + template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \ + [[kernel]] void qvm( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& in_vec_size [[buffer(5)]], \ + const constant int& out_vec_size [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid 
[[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_qvm_types(group_size, bits) \ + instantiate_qvm(float32, float, group_size, bits) \ + instantiate_qvm(float16, half, group_size, bits) \ + instantiate_qvm(bfloat16, bfloat16_t, group_size, bits) + +instantiate_qvm_types(128, 2) +instantiate_qvm_types(128, 4) +instantiate_qvm_types(128, 8) +instantiate_qvm_types( 64, 2) +instantiate_qvm_types( 64, 4) +instantiate_qvm_types( 64, 8) + #define instantiate_qmm_t(name, itype, group_size, bits) \ template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \ [[kernel]] void qmm_t( \ @@ -285,3 +538,31 @@ instantiate_qmm_t_types(128, 8) instantiate_qmm_t_types( 64, 2) instantiate_qmm_t_types( 64, 4) instantiate_qmm_t_types( 64, 8) + +#define instantiate_qmm_n(name, itype, group_size, bits) \ + template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \ + [[kernel]] void qmm_n( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& M [[buffer(5)]], \ + const constant int& N [[buffer(6)]], \ + const constant int& K [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_qmm_n_types(group_size, bits) \ + instantiate_qmm_n(float32, float, group_size, bits) \ + instantiate_qmm_n(float16, half, group_size, bits) \ + instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits) + +instantiate_qmm_n_types(128, 2) +instantiate_qmm_n_types(128, 4) +instantiate_qmm_n_types(128, 8) +instantiate_qmm_n_types( 64, 2) +instantiate_qmm_n_types( 64, 4) +instantiate_qmm_n_types( 64, 8) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 398bc8ed0..3997037e5 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -1,7 +1,6 @@ // Copyright © 2023 Apple Inc. 
#include -#include #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -23,97 +22,147 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { auto& biases_pre = inputs[3]; std::vector copies; - auto check_transpose = [&copies, &s](const array& arr) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (stx == arr.shape(-1) && sty == 1) { - return std::make_tuple(false, stx, arr); - } else if (stx == 1 && sty == arr.shape(-2)) { - return std::make_tuple(true, sty, arr); + auto ensure_row_contiguous = [&copies, &s](const array& arr) { + if (arr.flags().row_contiguous) { + return arr; } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy_gpu(arr, arr_copy, CopyType::General, s); copies.push_back(arr_copy); - size_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); + return arr_copy; } }; - auto [x_transposed, x_cols, x] = check_transpose(x_pre); - auto [w_transposed, w_cols, w] = check_transpose(w_pre); - auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre); - auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre); - - if (!w_transposed) { - throw std::runtime_error("The quantized weight should be transposed."); - } - - if (x_transposed || scales_transposed || biases_transposed) { - throw std::runtime_error("x, scales and biases should be row contiguous."); - } + auto x = ensure_row_contiguous(x_pre); + auto w = ensure_row_contiguous(w_pre); + auto scales = ensure_row_contiguous(scales_pre); + auto biases = ensure_row_contiguous(biases_pre); int D = x.shape(-1); int B = x.size() / D; + int O = out.shape(-1); + if (transpose_) { + // Route to the qmv kernel + if (B < 6) { + std::ostringstream kname; + kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; - // Route to the qmv kernel - if (B == 1) { - std::ostringstream kname; - kname << "qmv_" << (w_transposed ? 
"n_" : "t_") << type_to_name(out) - << "_gs_" << group_size_ << "_b_" << bits_; + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); - // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + int bo = 32; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, bo, 1); + MTL::Size grid_dims = MTL::Size(1, O / bo, B); - int O = w.size() / w_cols; + set_array_buffer(compute_encoder, w, 0); + set_array_buffer(compute_encoder, scales, 1); + set_array_buffer(compute_encoder, biases, 2); + set_array_buffer(compute_encoder, x, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&D, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); - int bo = 32; - int bd = 32; - MTL::Size group_dims = MTL::Size(bd, bo, 1); - MTL::Size grid_dims = MTL::Size(1, O / bo, B); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } - set_array_buffer(compute_encoder, w, 0); - set_array_buffer(compute_encoder, scales, 1); - set_array_buffer(compute_encoder, biases, 2); - set_array_buffer(compute_encoder, x, 3); - set_array_buffer(compute_encoder, out, 4); - compute_encoder->setBytes(&D, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); + // Route to the qmm_t kernel + else { + std::ostringstream kname; + kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); - } + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); - // Route to the qmm kernel - else { - std::ostringstream kname; - kname << "qmm_" << (w_transposed ? 
"t_" : "n_") << type_to_name(out) - << "_gs_" << group_size_ << "_b_" << bits_; + int wn = 2; + int wm = 2; + int bm = 32; + int bn = 32; + int bk = 64; + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); - // Encode and dispatch kernel - auto compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); - compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, w, 1); + set_array_buffer(compute_encoder, scales, 2); + set_array_buffer(compute_encoder, biases, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&B, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + compute_encoder->setBytes(&D, sizeof(int), 7); - int O = w.size() / w_cols; + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + } else { + // Route to the qvm kernel + if (B < 4) { + std::ostringstream kname; + kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; - int wn = 2; - int wm = 2; - int bm = 32; - int bn = 32; - int bk = 64; - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, x, 0); - set_array_buffer(compute_encoder, w, 1); - set_array_buffer(compute_encoder, scales, 2); - set_array_buffer(compute_encoder, biases, 3); - set_array_buffer(compute_encoder, out, 4); - compute_encoder->setBytes(&B, sizeof(int), 5); - compute_encoder->setBytes(&O, sizeof(int), 6); - compute_encoder->setBytes(&D, sizeof(int), 7); + int bo = 32; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, bo, 1); + MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, w, 1); + set_array_buffer(compute_encoder, scales, 2); + set_array_buffer(compute_encoder, biases, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&D, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Route to the qmm_n kernel + else { + std::ostringstream kname; + kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_" + << bits_; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int wn = 2; + int wm = 2; + int bm = 32; + int bn = 64; + int bk = 32; + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); + + if ((O % bn) != 0) { + std::ostringstream msg; + msg << "[quantized_matmul] The output size should be divisible by " + << bn << " but received " << O << "."; + throw std::runtime_error(msg.str()); + } + + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, w, 1); + set_array_buffer(compute_encoder, scales, 2); + set_array_buffer(compute_encoder, biases, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&B, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + compute_encoder->setBytes(&D, sizeof(int), 7); + + 
compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } } d.get_command_buffer(s.index)->addCompletedHandler( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e1f593aba..744aff68a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2618,10 +2618,11 @@ array quantized_matmul( const array& w, const array& scales, const array& biases, + bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { - auto x = in_x; + array x = in_x; if (w.dtype() != uint32) { std::ostringstream msg; @@ -2646,39 +2647,52 @@ array quantized_matmul( x = reshape(x, {-1, x_inner_dims}, s); } - int w_inner_dims = w.shape(0) * (32 / bits); - if (w_inner_dims != x_inner_dims) { + if (scales.ndim() != 2 || scales.shape() != biases.shape()) { std::ostringstream msg; - msg << "[quantized_matmul] Last dimension of first input with " - << "shape (..., " << x_inner_dims - << ") does not match the expanded first " - << "dimension of the quantized matrix " << w_inner_dims - << ", computed from shape " << w.shape() + msg << "[quantized_matmul] Scales and biases should have the same 2D shape. " + << "Received scales with shape " << scales.shape() + << " and biases with " << biases.shape(); + throw std::invalid_argument(msg.str()); + } + + if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) { + std::ostringstream msg; + msg << "[quantized_matmul] The shapes of the weight and scales are " + << "incompatible based on bits and group_size. w.shape() == " + << w.shape() << " and scales.shape() == " << scales.shape() << " with group_size=" << group_size << " and bits=" << bits; throw std::invalid_argument(msg.str()); } - int n_groups = x_inner_dims / group_size; - if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) { + // Calculate the expanded w's dims + int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0); + int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits; + + if (w_inner_dims != x_inner_dims) { std::ostringstream msg; - msg << "[quantized_matmul] Scales and biases provided do not match the " - << "quantization arguments (group_size=" << group_size - << ", bits=" << bits << "). 
Expected shapes (" << w.shape(1) << ", " - << x_inner_dims / group_size - << "), but got scales.shape=" << scales.shape() - << " and biases.shape=" << biases.shape(); + msg << "[quantized_matmul] Last dimension of first input with " + << "shape (..., " << x_inner_dims << ") does not match " + << "the expanded quantized matrix (" << w_inner_dims << ", " + << w_outer_dims << ") computed from shape " << w.shape() + << " with group_size=" << group_size << ", bits=" << bits + << " and transpose=" << std::boolalpha << transpose; throw std::invalid_argument(msg.str()); } + auto dtype = result_type({x, scales, biases}); auto out = array( - {x.shape(0), w.shape(1)}, - x.dtype(), - std::make_unique(to_stream(s), group_size, bits), - {x, w, scales, biases}); + {x.shape(0), w_outer_dims}, + dtype, + std::make_unique( + to_stream(s), group_size, bits, transpose), + {astype(x, dtype, s), + w, + astype(scales, dtype, s), + astype(biases, dtype, s)}); // If needed reshape x to the original batch shape if (original_shape.size() != 1) { - original_shape.push_back(w.shape(1)); + original_shape.push_back(w_outer_dims); out = reshape(out, original_shape, s); } diff --git a/mlx/ops.h b/mlx/ops.h index 6516ad008..0f7b52da4 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1041,6 +1041,7 @@ array quantized_matmul( const array& w, const array& scales, const array& biases, + bool transpose = true, int group_size = 64, int bits = 4, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 3366e463c..ffcaefb44 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1706,14 +1706,37 @@ std::vector QuantizedMatmul::vjp( const std::vector& primals, const array& cotan, const std::vector& argnums) { - throw std::runtime_error("QuantizedMatmul::vjp NYI"); + std::vector vjps; + + // We rely on the fact that w is always 2D so transpose is simple + for (auto arg : argnums) { + // gradient wrt to x + if (arg == 0) { + vjps.push_back(quantized_matmul( + cotan, + primals[1], + primals[2], + primals[3], + !transpose_, + group_size_, + bits_, + stream())); + } + + // gradient wrt to w_q, scales or biases + else { + throw std::runtime_error( + "QuantizedMatmul::vjp no gradient wrt the quantized matrix yet."); + } + } + return vjps; } array QuantizedMatmul::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - throw std::runtime_error("QuantizedMatmul::vjp NYI"); + throw std::runtime_error("QuantizedMatmul::jvp NYI"); } bool QuantizedMatmul::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index deeb498fa..e87e934ab 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1112,8 +1112,15 @@ class Power : public Primitive { class QuantizedMatmul : public Primitive { public: - explicit QuantizedMatmul(Stream stream, int group_size, int bits) - : Primitive(stream), group_size_(group_size), bits_(bits){}; + explicit QuantizedMatmul( + Stream stream, + int group_size, + int bits, + bool transpose) + : Primitive(stream), + group_size_(group_size), + bits_(bits), + transpose_(transpose){}; void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1129,6 +1136,7 @@ class QuantizedMatmul : public Primitive { private: int group_size_; int bits_; + bool transpose_; void eval(const std::vector& inputs, array& out); }; diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 6d2891db7..a52285633 100644 --- 
a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -81,9 +81,10 @@ class QuantizedLinear(Module): def __call__(self, x): x = mx.quantized_matmul( x, - self.weight.T, + self.weight, scales=self.scales, biases=self.biases, + transpose=True, group_size=self.group_size, bits=self.bits, ) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c152eeb97..49a17a5c8 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3072,12 +3072,13 @@ void init_ops(py::module_& m) { py::pos_only(), "scales"_a, "biases"_a, + "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, py::kw_only(), "stream"_a = none, R"pbdoc( - quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array + quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -3089,10 +3090,13 @@ void init_ops(py::module_& m) { w (array): Quantized matrix packed in unsigned integers scales (array): The scales to use per ``group_size`` elements of ``w`` biases (array): The biases to use per ``group_size`` elements of ``w`` + transpose (bool, optional): Defines whether to multiply with the + transposed ``w`` or not, namely whether we are performing + ``x @ w.T`` or ``x @ w``. (default: ``True``) group_size (int, optional): The size of the group in ``w`` that - shares a scale and bias. (default: 64) + shares a scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element in - ``w``. (default: 4) + ``w``. (default: ``4``) Returns: result (array): The result of the multiplication of ``x`` with ``w``. @@ -3146,9 +3150,9 @@ void init_ops(py::module_& m) { Args: w (array): Matrix to be quantized group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: 64) + scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element of - ``w`` in the returned quantized matrix. (default: 4) + ``w`` in the returned quantized matrix. (default: ``4``) Returns: (tuple): A tuple containing @@ -3187,9 +3191,9 @@ void init_ops(py::module_& m) { scales (array): The scales to use per ``group_size`` elements of ``w`` biases (array): The biases to use per ``group_size`` elements of ``w`` group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: 64) + scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element in - ``w``. (default: 4) + ``w``. (default: ``4``) Returns: result (array): The dequantized version of ``w`` diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 049f92fdb..5f038057d 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. 
 import unittest
+from itertools import product
 
 import mlx.core as mx
 import mlx_tests
@@ -19,62 +20,116 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for group_size in [128, 64]:
-            for bits in [2, 4, 8]:
-                for M in [8, 32, 33, 64]:
-                    for N in [512, 1024]:
-                        for K in [512, 1024]:
-                            with self.subTest(
-                                shape=(M, N, K), group_size=group_size, bits=bits
-                            ):
-                                x = mx.random.normal(shape=(M, K), key=k1)
-                                w = mx.random.normal(shape=(N, K), key=k2)
-                                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                                w_hat = mx.dequantize(
-                                    w_q, scales, biases, group_size, bits
-                                )
-                                y_q = mx.quantized_matmul(
-                                    x, w_q.T, scales, biases, group_size, bits
-                                )
-                                y_hat = x @ w_hat.T
-                                self.assertEqual(y_q.shape, y_hat.shape)
-                                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        tests = product(
+            [128, 64],  # group_size
+            [2, 4, 8],  # bits
+            [8, 32, 33, 64],  # M
+            [512, 1024],  # N
+            [512, 1024],  # K
+            [True, False],  # transposed
+        )
+        for group_size, bits, M, N, K, transposed in tests:
+            with self.subTest(
+                shape=(M, N, K),
+                group_size=group_size,
+                bits=bits,
+                transposed=transposed,
+            ):
+                x = mx.random.normal(shape=(M, K), key=k1)
+                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         group_size = 64
         bits = 4
-        w = mx.random.normal(shape=(32, 128), key=k2)
+        w = mx.random.normal(shape=(32, 256), key=k2)
         w_q, scales, biases = mx.quantize(w, group_size, bits)
         w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
-        for s in [(3, 128), (2, 1, 7, 128)]:
-            x = mx.random.normal(shape=(3, 128), key=k1)
-            y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
+        for s in [(3, 256), (2, 1, 7, 256)]:
+            x = mx.random.normal(shape=s, key=k1)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
             y_hat = x @ w_hat.T
             self.assertEqual(y_q.shape, y_hat.shape)
             self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+        w = mx.random.normal(shape=(256, 256), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        for s in [(3, 256), (2, 1, 7, 256)]:
+            x = mx.random.normal(shape=s, key=k1)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+            y_hat = x @ w_hat
+            self.assertEqual(y_q.shape, y_hat.shape)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
     def test_qmv(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for group_size in [128, 64]:
-            for bits in [2, 4, 8]:
-                for M in [512, 1024]:
-                    for N in [512, 1024]:
-                        with self.subTest(
-                            shape=(M, N), group_size=group_size, bits=bits
-                        ):
-                            x = mx.random.normal(shape=(1, N), key=k1)
-                            w = mx.random.normal(shape=(M, N), key=k2)
-                            w_q, scales, biases = mx.quantize(w, group_size, bits)
-                            w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
-                            y_q = mx.quantized_matmul(
-                                x, w_q.T, scales, biases, group_size, bits
-                            )
-                            y_hat = x @ w_hat.T
-                            self.assertEqual(y_q.shape, y_hat.shape)
-                            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        tests = product(
+            [128, 64],  # group_size
+            [2, 4, 8],  # bits
+            [512, 1024],  # M
+            [512, 1024],  # N
+        )
+        for group_size, bits, M, N in tests:
+            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+                x = mx.random.normal(shape=(1, N), key=k1)
+                w = mx.random.normal(shape=(M, N), key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, True, group_size, bits
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+    def test_qvm(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+        tests = product(
+            [128, 64],  # group_size
+            [2, 4, 8],  # bits
+            [512, 1024],  # M
+            [512, 1024],  # N
+        )
+        for group_size, bits, M, N in tests:
+            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+                x = mx.random.normal(shape=(1, N), key=k1)
+                w = mx.random.normal(shape=(N, M), key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, False, group_size, bits
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+    def test_throw(self):
+        x = mx.random.normal(shape=(10, 512))
+        w = mx.random.normal(shape=(32, 512))
+        w_q, scales, biases = mx.quantize(w)
+
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, w_q.T, scales, biases)
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, w_q.T, scales.T, biases)
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, w_q, scales, biases, False)
+        with self.assertRaises(ValueError):
+            mx.quantized_matmul(x, w_q, scales.T, biases.T)
+        y = mx.quantized_matmul(x, w_q, scales, biases, True)
+        mx.eval(y)
 
 
 if __name__ == "__main__":
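
A short usage sketch, not part of the patch itself, of the `transpose` argument this diff introduces, assuming only the Python binding and shapes documented above: with `transpose=True` (the default) the packed weight is multiplied as `x @ w.T`, with `transpose=False` as `x @ w`, and both results are checked against `mx.dequantize` the same way the new tests do.

    import mlx.core as mx

    group_size, bits = 64, 4
    x = mx.random.normal(shape=(8, 512))

    # Weight stored as (out_features, in_features): multiply by its transpose.
    w = mx.random.normal(shape=(256, 512))
    w_q, scales, biases = mx.quantize(w, group_size, bits)
    y = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
    y_ref = x @ mx.dequantize(w_q, scales, biases, group_size, bits).T
    print(y.shape, (y - y_ref).abs().max())  # (8, 256), within the 1e-3 test tolerance

    # Weight stored as (in_features, out_features): multiply without transposing.
    w = mx.random.normal(shape=(512, 256))
    w_q, scales, biases = mx.quantize(w, group_size, bits)
    y = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
    y_ref = x @ mx.dequantize(w_q, scales, biases, group_size, bits)
    print(y.shape, (y - y_ref).abs().max())  # (8, 256), within the 1e-3 test tolerance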