diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 71a3d6c01..737ce27b5 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -23,6 +23,16 @@ def none_or_list(x): return [int(xi) for xi in x.split(",")] +def dtype_from_str(x): + if x == "": + return mx.float32 + else: + dt = getattr(mx, x) + if not isinstance(dt, mx.Dtype): + raise ValueError(f"{x} is not an mlx dtype") + return dt + + def bench(f, *args): for i in range(10): f(*args) @@ -49,6 +59,15 @@ def matmul(x, y): mx.eval(ys) +def quant_matmul(x, w, s, b): + groups = x.shape[-1] // s.shape[-1] + width = 32 // (x.shape[-1] // w.shape[0]) + ys = [] + for i in range(10): + ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width)) + mx.eval(ys) + + def conv1d(x, y): ys = [] for i in range(10): @@ -296,9 +315,7 @@ if __name__ == "__main__": parser.add_argument( "--fused", action="store_true", help="Use fused functions where possible" ) - parser.add_argument( - "--dtype", choices=["float32", "float16", "bfloat16"], default="float32" - ) + parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append") args = parser.parse_args() @@ -315,11 +332,15 @@ if __name__ == "__main__": mx.set_default_device(mx.cpu) else: mx.set_default_device(mx.gpu) - dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[ - args.dtype - ] + + types = args.dtype + if not types: + types = [mx.float32] + if len(types) < len(args.size): + types = types + [types[0]] * (len(args.size) - len(types)) + xs = [] - for size in args.size: + for size, dtype in zip(args.size, types): xs.append(mx.random.normal(size).astype(dtype)) for i, t in enumerate(args.transpose): if t is None: @@ -335,6 +356,9 @@ if __name__ == "__main__": elif args.benchmark == "matmul": print(bench(matmul, *xs)) + elif args.benchmark == "quant_matmul": + print(bench(quant_matmul, *xs)) + elif args.benchmark == "linear": print(bench(linear, *xs)) diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py index d556dd051..1e8649537 100644 --- a/benchmarks/python/comparative/bench_torch.py +++ b/benchmarks/python/comparative/bench_torch.py @@ -22,6 +22,16 @@ def none_or_list(x): return [int(xi) for xi in x.split(",")] +def dtype_from_str(x): + if x == "": + return torch.float32 + else: + dt = getattr(torch, x) + if not isinstance(dt, torch.dtype): + raise ValueError(f"{x} is not a torch dtype") + return dt + + def bench(f, *args): for i in range(10): f(*args) @@ -312,7 +322,7 @@ if __name__ == "__main__": parser.add_argument( "--fused", action="store_true", help="Use fused functions where possible" ) - parser.add_argument("--dtype", choices=["float32", "float16"], default="float32") + parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append") args = parser.parse_args() @@ -327,9 +337,15 @@ if __name__ == "__main__": torch.set_num_threads(1) device = "cpu" if args.cpu else "mps" - dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype] + + types = args.dtype + if not types: + types = [torch.float32] + if len(types) < len(args.size): + types = types + [types[0]] * (len(args.size) - len(types)) + xs = [] - for size in args.size: + for size, dtype in zip(args.size, types): xs.append(torch.randn(*size).to(device).to(dtype)) for i, t in enumerate(args.transpose): if t is None: diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt index 
34269f9c2..e3c16910a 100644 --- a/mlx/backend/accelerate/CMakeLists.txt +++ b/mlx/backend/accelerate/CMakeLists.txt @@ -4,6 +4,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ) diff --git a/mlx/backend/accelerate/quantized.cpp b/mlx/backend/accelerate/quantized.cpp new file mode 100644 index 000000000..dc545343e --- /dev/null +++ b/mlx/backend/accelerate/quantized.cpp @@ -0,0 +1,107 @@ +// Copyright © 2023 Apple Inc. + +#include + +#include + +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +void _qmm_t_4_64( + float* result, + const float* x, + const uint32_t* w, + const float* scales, + const float* biases, + int M, + int N, + int K) { + constexpr int width = 4; + constexpr int groups = 64; + constexpr int bitmask = (1 << width) - 1; + constexpr int pack_factor = 32 / width; + constexpr int packs_in_group = groups / pack_factor; + const int Kg = K / groups; + const int Kw = K / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const float* scales_local = scales; + const float* biases_local = biases; + + for (int n = 0; n < N; n++) { + const simd_float16* x_local = (simd_float16*)x; + simd_float16 sum = 0; + for (int k = 0; k < K; k += groups) { + float scale = *scales_local++; + float bias = *biases_local++; + + for (int kw = 0; kw < packs_in_group; kw += 2) { + // TODO: vectorize this properly + simd_uint16 wi; + for (int e = 0; e < 2; e++) { + uint32_t wii = *w_local++; + for (int p = 0; p < 8; p++) { + wi[e * 8 + p] = wii & bitmask; + wii >>= width; + } + } + simd_float16 wf = simd_float(wi); + wf *= scale; + wf += bias; + + sum += (*x_local) * wf; + x_local++; + } + } + + *result = simd_reduce_add(sum); + result++; + } + + x += K; + } +} + +} // namespace + +void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 4); + + auto& x = inputs[0]; + auto& w = inputs[1]; + auto& scales = inputs[2]; + auto& biases = inputs[3]; + + if (w.strides()[0] != 1) { + throw std::runtime_error("The quantized weight should be transposed"); + } + + if (!x.flags().row_contiguous || !scales.flags().row_contiguous || + !biases.flags().row_contiguous) { + throw std::runtime_error("x, scales and biases should be row contiguous."); + } + + if (x.dtype() == float32 && width_ == 4 && groups_ == 64) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + int K = x.shape(-1); + int M = x.size() / K; + int N = w.shape(1); + _qmm_t_4_64( + out.data(), + x.data(), + w.data(), + scales.data(), + biases.data(), + M, + N, + K); + } else { + eval(inputs, out); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 5ab4c2979..077a0353e 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 126d953f5..3cefbbcb2 100644 --- a/mlx/backend/common/default_primitives.cpp +++ 
b/mlx/backend/common/default_primitives.cpp @@ -62,6 +62,7 @@ DEFAULT(NotEqual) DEFAULT(Pad) DEFAULT(Partition) DEFAULT(Power) +DEFAULT(QuantizedMatmul) DEFAULT(RandomBits) DEFAULT(Reduce) DEFAULT(Reshape) diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp new file mode 100644 index 000000000..2120d881a --- /dev/null +++ b/mlx/backend/common/quantized.cpp @@ -0,0 +1,183 @@ +// Copyright © 2023 Apple Inc. + +#include + +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void _qmm_t( + T* result, + const T* x, + const uint32_t* w, + const T* scales, + const T* biases, + int M, + int N, + int K) { + constexpr int bitmask = (1 << width) - 1; + constexpr int pack_factor = 32 / width; + constexpr int packs_in_group = groups / pack_factor; + const int Kg = K / groups; + const int Kw = K / pack_factor; + + for (int m = 0; m < M; m++) { + const uint32_t* w_local = w; + const T* scales_local = scales; + const T* biases_local = biases; + + for (int n = 0; n < N; n++) { + const T* x_local = x; + T sum = 0; + for (int k = 0; k < K; k += groups) { + T scale = *scales_local++; + T bias = *biases_local++; + + for (int kw = 0; kw < packs_in_group; kw++) { + uint32_t wi = *w_local++; + +#pragma clang loop unroll(full) + for (int p = 0; p < pack_factor; p++) { + sum += (*x_local++) * (scale * static_cast(wi & bitmask) + bias); + wi >>= width; + } + } + } + *result = sum; + result++; + } + + x += K; + } +} + +template +void _qmm_t_dispatch_typed( + T* result, + const T* x, + const uint32_t* w, + const T* scales, + const T* biases, + int M, + int N, + int K, + int width, + int groups) { + switch (width) { + case 2: { + switch (groups) { + case 64: + return _qmm_t(result, x, w, scales, biases, M, N, K); + case 128: + return _qmm_t(result, x, w, scales, biases, M, N, K); + } + } + case 4: { + switch (groups) { + case 64: + return _qmm_t(result, x, w, scales, biases, M, N, K); + case 128: + return _qmm_t(result, x, w, scales, biases, M, N, K); + } + } + case 8: { + switch (groups) { + case 64: + return _qmm_t(result, x, w, scales, biases, M, N, K); + case 128: + return _qmm_t(result, x, w, scales, biases, M, N, K); + } + } + } + std::ostringstream msg; + msg << "Quantization type not supported. Provided bit width=" << width + << " and groups=" << groups << ". 
The supported options are width in " + << "{2, 4, 8} and groups in {64, 128}."; + throw std::invalid_argument(msg.str()); +} + +void _qmm_t_dispatch( + array out, + const array& x, + const array& w, + const array& scales, + const array& biases, + int width, + int groups) { + int K = x.shape(-1); + int M = x.size() / K; + int N = w.shape(1); + + switch (x.dtype()) { + case float32: + _qmm_t_dispatch_typed( + out.data(), + x.data(), + w.data(), + scales.data(), + biases.data(), + M, + N, + K, + width, + groups); + break; + case float16: + _qmm_t_dispatch_typed( + out.data(), + x.data(), + w.data(), + scales.data(), + biases.data(), + M, + N, + K, + width, + groups); + break; + case bfloat16: + _qmm_t_dispatch_typed( + out.data(), + x.data(), + w.data(), + scales.data(), + biases.data(), + M, + N, + K, + width, + groups); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } +} + +} // namespace + +void QuantizedMatmul::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 4); + + auto& x = inputs[0]; + auto& w = inputs[1]; + auto& scales = inputs[2]; + auto& biases = inputs[3]; + + if (w.strides()[0] != 1) { + throw std::runtime_error("The quantized weight should be transposed"); + } + + if (!x.flags().row_contiguous || !scales.flags().row_contiguous || + !biases.flags().row_contiguous) { + throw std::runtime_error("x, scales and biases should be row contiguous."); + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + _qmm_t_dispatch(out, x, w, scales, biases, width_, groups_); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index d5fd2e07f..e12402c18 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -10,6 +10,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 8fc5eac30..e65430c25 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -18,6 +18,7 @@ set( "copy" "gemm" "gemv" + "quantized" "random" "reduce" "scan" diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal new file mode 100644 index 000000000..eb48c92f1 --- /dev/null +++ b/mlx/backend/metal/kernels/quantized.metal @@ -0,0 +1,287 @@ +// Copyright © 2023 Apple Inc. 
+ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/gemm/gemm.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; + +template +[[kernel]] void qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE"); + + constexpr int bitmask = (1 << width) - 1; + constexpr int el_per_thread = 32 / width; + constexpr int colgroup = BN * el_per_thread; + constexpr int groups_per_block = colgroup / groups; + constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE; + + threadgroup T scales_block[BM * groups_per_block]; + threadgroup T biases_block[BM * groups_per_block]; + threadgroup T x_block[colgroup]; + + thread uint32_t w_local; + thread T result = 0; + thread T scale = 1; + thread T bias = 0; + thread T x_thread[el_per_thread]; + + // Adjust positions + const int in_vec_size_w = in_vec_size / el_per_thread; + const int in_vec_size_g = in_vec_size / groups; + int out_row = tid.y * BM + simd_gid; + w += out_row * in_vec_size_w; + scales += out_row * in_vec_size_g; + biases += out_row * in_vec_size_g; + x += tid.z * in_vec_size; + y += tid.z * out_vec_size; + + // Loop over in_vec in blocks of colgroup + for (int i=0; i(w_local & bitmask) + bias) * x_thread[k]; + w_local >>= width; + } + } + + // Accumulate in the simdgroup + result = simd_sum(result); + + // Store the result + if (simd_lid == 0) { + y[out_row] = result; + } +} + + +template +[[kernel]] void qmm_t( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& M [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& K [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + const uint lidy = lid / SIMD_SIZE; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int bitmask = (1 << width) - 1; + constexpr int el_per_int = 32 / width; + constexpr int ints_per_block = BK / el_per_int; + constexpr int groups_per_block = (BK / groups > 0) ? 
(BK / groups) : 1; + constexpr int groups_per_simd = BN / (WM * WN); + constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN); + + // Using the kernel just as a type to instantiate the appropriate BlockMMA + // and constexpr size calculations + using mma_t = BlockMMA; + using loader_x_t = BlockLoader; + + threadgroup T scales_block[BN * groups_per_block]; + threadgroup T biases_block[BN * groups_per_block]; + threadgroup T Xs[BM * BK]; + threadgroup T Ws[BN * BK]; + + // Set the block + const int K_w = K / el_per_int; + const int K_g = K / groups; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * K; + w += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * N + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + for (int k=0; k(wi & bitmask) + bias; + wi >>= width; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(Xs, Ws); + + // Prepare for next iteration + loader_x.next(); + w += ints_per_block; + // scales and biases cannot be advanced because they would have to be + // advanced every other iteration or sth. + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + + +#define instantiate_qmv(name, itype, groups, width) \ + template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \ + [[kernel]] void qmv( \ + const device uint32_t* w [[buffer(0)]], \ + const device itype* scales [[buffer(1)]], \ + const device itype* biases [[buffer(2)]], \ + const device itype* x [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& in_vec_size [[buffer(5)]], \ + const constant int& out_vec_size [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_qmv_types(groups, width) \ + instantiate_qmv(float32, float, groups, width) \ + instantiate_qmv(float16, half, groups, width) \ + instantiate_qmv(bfloat16, bfloat16_t, groups, width) + +instantiate_qmv_types(128, 2) +instantiate_qmv_types(128, 4) +instantiate_qmv_types(128, 8) +instantiate_qmv_types( 64, 2) +instantiate_qmv_types( 64, 4) +instantiate_qmv_types( 64, 8) + +#define instantiate_qmm_t(name, itype, groups, width) \ + template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \ + [[kernel]] void qmm_t( \ + const device itype* x [[buffer(0)]], \ + const device uint32_t* w [[buffer(1)]], \ + const device itype* scales [[buffer(2)]], \ + const device itype* biases [[buffer(3)]], \ + device itype* y [[buffer(4)]], \ + const constant int& M [[buffer(5)]], \ + const constant int& N [[buffer(6)]], \ + const constant int& K [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_qmm_t_types(groups, width) \ + instantiate_qmm_t(float32, float, groups, width) \ + instantiate_qmm_t(float16, half, groups, width) \ + instantiate_qmm_t(bfloat16, bfloat16_t, groups, width) + +instantiate_qmm_t_types(128, 2) 
+instantiate_qmm_t_types(128, 4) +instantiate_qmm_t_types(128, 8) +instantiate_qmm_t_types( 64, 2) +instantiate_qmm_t_types( 64, 4) +instantiate_qmm_t_types( 64, 8) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp new file mode 100644 index 000000000..7d8225797 --- /dev/null +++ b/mlx/backend/metal/quantized.cpp @@ -0,0 +1,123 @@ +// Copyright © 2023 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 4); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& x_pre = inputs[0]; + auto& w_pre = inputs[1]; + auto& scales_pre = inputs[2]; + auto& biases_pre = inputs[3]; + + std::vector copies; + auto check_transpose = [&copies, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + auto [x_transposed, x_cols, x] = check_transpose(x_pre); + auto [w_transposed, w_cols, w] = check_transpose(w_pre); + auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre); + auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre); + + if (!w_transposed) { + throw std::runtime_error("The quantized weight should be transposed."); + } + + if (x_transposed || scales_transposed || biases_transposed) { + throw std::runtime_error("x, scales and biases should be row contiguous."); + } + + int D = x.shape(-1); + int B = x.size() / D; + + // Route to the qmv kernel + if (B == 1) { + std::ostringstream kname; + kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out) + << "_groups_" << groups_ << "_width_" << width_; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int O = w.size() / w_cols; + + int bo = 32; + int bd = 32; + MTL::Size group_dims = MTL::Size(bd, bo, 1); + MTL::Size grid_dims = MTL::Size(1, O / bo, B); + + set_array_buffer(compute_encoder, w, 0); + set_array_buffer(compute_encoder, scales, 1); + set_array_buffer(compute_encoder, biases, 2); + set_array_buffer(compute_encoder, x, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&D, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + // Route to the qmm kernel + else { + std::ostringstream kname; + kname << "qmm_" << (w_transposed ? 
"t_" : "n_") << type_to_name(out) + << "_groups_" << groups_ << "_width_" << width_; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int O = w.size() / w_cols; + + int wn = 2; + int wm = 2; + int bm = 32; + int bn = 32; + int bk = 64; + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1); + + set_array_buffer(compute_encoder, x, 0); + set_array_buffer(compute_encoder, w, 1); + set_array_buffer(compute_encoder, scales, 2); + set_array_buffer(compute_encoder, biases, 3); + set_array_buffer(compute_encoder, out, 4); + compute_encoder->setBytes(&B, sizeof(int), 5); + compute_encoder->setBytes(&O, sizeof(int), 6); + compute_encoder->setBytes(&D, sizeof(int), 7); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); +} + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index b9cde4426..e2f92b093 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -58,6 +58,7 @@ NO_GPU(NotEqual) NO_GPU(Pad) NO_GPU(Partition) NO_GPU(Power) +NO_GPU(QuantizedMatmul) NO_GPU(RandomBits) NO_GPU(Reduce) NO_GPU(Reshape) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c25ac28bf..9a4a9b2a9 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2564,4 +2564,75 @@ array conv2d( {in, wt}); } +array quantized_matmul( + const array& in_x, + const array& w, + const array& scales, + const array& biases, + int groups /* = 128 */, + int width /* = 4 */, + StreamOrDevice s /* = {} */) { + auto x = in_x; + + if (w.dtype() != uint32) { + std::ostringstream msg; + msg << "[quantized_matmul] The weight matrix should be uint32 " + << "but received" << w.dtype(); + throw std::invalid_argument(msg.str()); + } + if (w.ndim() != 2) { + std::ostringstream msg; + msg << "[quantized_matmul] Batched quantized matmul is not supported for now " + << "received w with shape " << w.shape(); + throw std::invalid_argument(msg.str()); + } + + // Keep x's batch dimensions to reshape it back after the matmul + auto original_shape = x.shape(); + int x_inner_dims = original_shape.back(); + original_shape.pop_back(); + + // Reshape x into a matrix if it isn't already one + if (x.ndim() != 2) { + x = reshape(x, {-1, x_inner_dims}, s); + } + + int w_inner_dims = w.shape(0) * (32 / width); + if (w_inner_dims != x_inner_dims) { + std::ostringstream msg; + msg << "[quantized_matmul] Last dimension of first input with " + << "shape (..., " << x_inner_dims + << ") does not match the expanded first " + << "dimension of the quantized matrix " << w_inner_dims + << ", computed from shape " << w.shape() << " with groups=" << groups + << " and width=" << width; + throw std::invalid_argument(msg.str()); + } + + int n_groups = x_inner_dims / groups; + if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) { + std::ostringstream msg; + msg << "[quantized_matmul] Scales and biases provided do not match the " + << "quantization arguments (groups=" << groups << ", width=" << width + << "). 
Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups + << "), but got scales.shape=" << scales.shape() + << " and biases.shape=" << biases.shape(); + throw std::invalid_argument(msg.str()); + } + + auto out = array( + {x.shape(0), w.shape(1)}, + x.dtype(), + std::make_unique(to_stream(s), groups, width), + {x, w, scales, biases}); + + // If needed reshape x to the original batch shape + if (original_shape.size() != 1) { + original_shape.push_back(w.shape(1)); + out = reshape(out, original_shape, s); + } + + return out; +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 50a0dc1eb..977dfc970 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1028,4 +1028,14 @@ array load(std::shared_ptr in_stream, StreamOrDevice s = {}); /** Load array from file in .npy format */ array load(const std::string& file, StreamOrDevice s = {}); +/** Quantized matmul multiplies x with a quantized matrix w*/ +array quantized_matmul( + const array& x, + const array& w, + const array& scales, + const array& biases, + int groups = 128, + int width = 4, + StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 6d275097c..f67340921 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1696,6 +1696,31 @@ std::pair Power::vmap( return {power(a, b, stream()), to_ax}; } +std::pair QuantizedMatmul::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::runtime_error("QuantizedMatmul::vmap NYI"); +} + +std::vector QuantizedMatmul::vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) { + throw std::runtime_error("QuantizedMatmul::vjp NYI"); +} + +array QuantizedMatmul::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + throw std::runtime_error("QuantizedMatmul::vjp NYI"); +} + +bool QuantizedMatmul::is_equivalent(const Primitive& other) const { + const QuantizedMatmul& qm_other = static_cast(other); + return groups_ == qm_other.groups_ && width_ == qm_other.width_; +} + std::pair RandomBits::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 8916c4fa1..fbbde2dd1 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1110,6 +1110,29 @@ class Power : public Primitive { void eval(const std::vector& inputs, array& out); }; +class QuantizedMatmul : public Primitive { + public: + explicit QuantizedMatmul(Stream stream, int groups, int width) + : Primitive(stream), groups_(groups), width_(width){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(QuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + + private: + int groups_; + int width_; + + void eval(const std::vector& inputs, array& out); +}; + class RandomBits : public Primitive { public: explicit RandomBits(Stream stream, const std::vector& shape, int width) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 14b281d82..39dbad5b8 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2977,4 +2977,36 @@ void init_ops(py::module_& m) { Returns: result (array): An array of the same type as ``a`` rounded to the given number of decimals. 
)pbdoc"); + m.def( + "quantized_matmul", + &quantized_matmul, + "x"_a, + "w"_a, + py::pos_only(), + "scales"_a, + "biases"_a, + "groups"_a = 128, + "width"_a = 4, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array + + Perform the matrix multiplication with the quantized matrix ``w``. The + quantization uses one floating point scale and bias per ``groups`` of + elements. Each element in ``w`` takes ``width`` bits and is packed in an + unsigned 32 bit integer. + + Args: + x (array): Input array + w (array): Quantized matrix packed in unsigned integers + scales (array): The scales to use per ``groups`` elements of ``w`` + biases (array): The biases to use per ``groups`` elements of ``w`` + groups (int): The size of the group in ``w`` that shares a scale and + bias. (default: 128) + width (int): The bitwidth of the elements in ``w``. (default: 4) + + Returns: + result (array): The result of the multiplication of ``x`` with ``w``. + )pbdoc"); } diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py new file mode 100644 index 000000000..48493df26 --- /dev/null +++ b/python/tests/test_quantized.py @@ -0,0 +1,112 @@ +# Copyright © 2023 Apple Inc. + +import unittest + +import mlx.core as mx +import mlx_tests + + +def select_bits(w, width, start): + shift_left = 32 - (start + width) + shift_right = shift_left + start + return (w * (2**shift_left)) // (2**shift_right) + + +def dequantize(w, scales, biases, width): + w_full = mx.concatenate( + [select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1 + ) + w_full = w_full.reshape(len(w), scales.shape[-1], -1) + w_full = scales[..., None] * w_full + biases[..., None] + w_full = w_full.reshape(len(w), -1) + + return w_full + + +def quantize(w, width, groups): + w = w.reshape(len(w), -1, groups) + w_max = w.max(-1, keepdims=True) + w_min = w.min(-1, keepdims=True) + delta = (w_max - w_min) / (2**width - 1) + + w_int = mx.round((w - w_min) / delta).astype(mx.uint32) + scales = delta.squeeze(-1) + biases = w_min.squeeze(-1) + + shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32) + w_int = w_int.reshape(len(w), -1, 32 // width) + w_int = w_int * shifts[None, None] + packed_w = w_int.sum(-1) + + return packed_w, scales, biases + + +class TestQuantized(mlx_tests.MLXTestCase): + def test_quantize_dequantize(self): + w = mx.random.normal(shape=(128, 128)) + w_q, scales, biases = quantize(w, 4, 64) + w_hat = dequantize(w_q, scales, biases, 4) + w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4) + self.assertLess((w_hat - w_hat2).abs().max(), 1e-6) + + def test_qmm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + for groups in [128, 64]: + for width in [2, 4, 8]: + for M in [8, 32, 33, 64]: + for N in [512, 1024]: + for K in [512, 1024]: + with self.subTest( + shape=(M, N, K), groups=groups, width=width + ): + x = mx.random.normal(shape=(M, K), key=k1) + w = mx.random.normal(shape=(N, K), key=k2) + w_q, scales, biases = quantize(w, width, groups) + w_hat = dequantize(w_q, scales, biases, width) + y_q = mx.quantized_matmul( + x, w_q.T, scales, biases, width=width, groups=groups + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 0.1) + + def test_qmm_shapes(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + groups = 64 + width = 4 + w = 
mx.random.normal(shape=(32, 128), key=k2) + w_q, scales, biases = quantize(w, width, groups) + w_hat = dequantize(w_q, scales, biases, width) + for s in [(3, 128), (2, 1, 7, 128)]: + x = mx.random.normal(shape=s, key=k1) + y_q = mx.quantized_matmul( + x, w_q.T, scales, biases, width=width, groups=groups + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 0.1) + + def test_qmv(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + for groups in [128, 64]: + for width in [2, 4, 8]: + for M in [512, 1024]: + for N in [512, 1024]: + # with self.subTest(shape=(M, N), groups=groups, width=width): + x = mx.random.normal(shape=(1, N), key=k1) + w = mx.random.normal(shape=(M, N), key=k2) + w_q, scales, biases = quantize(w, width, groups) + w_hat = dequantize(w_q, scales, biases, width) + y_q = mx.quantized_matmul( + x, w_q.T, scales, biases, width=width, groups=groups + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 0.1) + + +if __name__ == "__main__": + unittest.main()
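
Note on usage: the op, the CPU/Metal kernels, and the tests in this diff all follow the same flow: pack a floating-point weight matrix into `width`-bit values (one scale and bias per `groups` elements of each row), then call `mx.quantized_matmul` with the packed matrix transposed. The sketch below is a minimal end-to-end check of that flow, not part of the diff; it assumes the `quantize`/`dequantize` helpers from `python/tests/test_quantized.py` above are in scope, and the shapes and loose tolerance are illustrative (mirroring the unit tests rather than any documented accuracy guarantee).

import mlx.core as mx

# Illustrative shapes only; groups/width must be in {64, 128} x {2, 4, 8}.
groups, width = 64, 4
x = mx.random.normal(shape=(8, 512))
w = mx.random.normal(shape=(1024, 512))

# Pack w with the pure-Python helpers from the test file above.
w_q, scales, biases = quantize(w, width, groups)
w_hat = dequantize(w_q, scales, biases, width)

# The kernels expect the packed weight transposed (hence w_q.T), matching the
# "quantized weight should be transposed" checks in the backends.
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, groups=groups, width=width)
y_ref = x @ w_hat.T

# Quantization error should be small; the tests use a 0.1 tolerance.
print((y_q - y_ref).abs().max())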