Support for quantized matmul with w and w^T (#349)

* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt to input for quantized_matmul
This commit is contained in:
Angelos Katharopoulos 2024-01-03 14:22:36 -08:00 committed by GitHub
parent d7ac050f4b
commit e7f5059fe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 718 additions and 193 deletions

View File

@ -4,6 +4,7 @@ import argparse
import math
import os
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@ -59,15 +60,23 @@ def matmul(x, y):
mx.eval(ys)
def quant_matmul(x, w, s, b):
groups = x.shape[-1] // s.shape[-1]
width = 32 // (x.shape[-1] // w.shape[0])
def _quant_matmul(x, w, s, b, group_size, bits):
ys = []
for i in range(10):
ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
mx.eval(ys)
quant_matmul = {
"quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
"quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
"quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
"quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
}
def conv1d(x, y):
ys = []
for i in range(10):
@ -356,8 +365,8 @@ if __name__ == "__main__":
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark == "quant_matmul":
print(bench(quant_matmul, *xs))
elif args.benchmark.startswith("quant_matmul"):
print(bench(quant_matmul[args.benchmark], *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))

View File

@ -76,20 +76,16 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& scales = inputs[2];
auto& biases = inputs[3];
if (w.strides()[0] != 1) {
throw std::runtime_error("The quantized weight should be transposed");
}
bool condition =
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
scales.flags().row_contiguous && biases.flags().row_contiguous &&
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
!biases.flags().row_contiguous) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
int N = out.shape(-1);
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),

View File

@ -1,13 +1,62 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, int bits, int group_size>
void _qmm(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
const int Ng = N / group_size;
const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
std::fill(result, result + N, 0);
for (int k = 0; k < K; k++) {
T* result_local = result;
T xi = *x++;
for (int n = 0; n < N; n += group_size) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int ng = 0; ng < packs_in_group; ng++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
(*result_local++) +=
xi * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= bits;
}
}
}
}
result += N;
}
}
template <typename T, int bits, int group_size>
void _qmm_t(
T* result,
@ -55,7 +104,7 @@ void _qmm_t(
}
template <typename T>
void _qmm_t_dispatch_typed(
void _qmm_dispatch_typed(
T* result,
const T* x,
const uint32_t* w,
@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed(
int N,
int K,
int group_size,
int bits) {
int bits,
bool transposed_w) {
switch (bits) {
case 2: {
switch (group_size) {
case 64:
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 4: {
switch (group_size) {
case 64:
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
case 8: {
switch (group_size) {
case 64:
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
}
case 128:
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
if (transposed_w) {
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
}
@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed(
throw std::invalid_argument(msg.str());
}
void _qmm_t_dispatch(
void _qmm_dispatch(
array out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int bits,
int group_size) {
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
int N = out.shape(-1);
switch (x.dtype()) {
case float32:
_qmm_t_dispatch_typed<float>(
_qmm_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
@ -124,10 +199,11 @@ void _qmm_t_dispatch(
N,
K,
bits,
group_size);
group_size,
transposed_w);
break;
case float16:
_qmm_t_dispatch_typed<float16_t>(
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
@ -137,10 +213,11 @@ void _qmm_t_dispatch(
N,
K,
bits,
group_size);
group_size,
transposed_w);
break;
case bfloat16:
_qmm_t_dispatch_typed<bfloat16_t>(
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
@ -150,7 +227,8 @@ void _qmm_t_dispatch(
N,
K,
bits,
group_size);
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
@ -163,22 +241,28 @@ void _qmm_t_dispatch(
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
if (w.strides()[0] != 1) {
throw std::runtime_error("The quantized weight should be transposed");
}
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
!biases.flags().row_contiguous) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
} // namespace mlx::core

View File

@ -104,6 +104,108 @@ template <typename T, const int BM, const int BN, const int group_size, const in
}
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
static_assert(BN == BM, "qvm expects a block size of 32x32");
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_int = 32 / bits;
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[BM];
thread uint32_t w_local;
thread T result[el_per_int] = {0};
thread T scale = 1;
thread T bias = 0;
thread T x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = (tid.y * BN + simd_gid) * el_per_int;
w += out_col / el_per_int;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
if (out_col >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = x[simd_lid + i];
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load in_vec, scale, bias to registers
x_local = x_block[simd_lid];
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = w[(i + simd_lid) * out_vec_size_w];
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
}
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
y[k] = result[k];
}
}
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
@ -133,8 +235,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
constexpr int groups_per_simd = BN / (WM * WN);
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
// Using the kernel just as a type to instantiate the appropriate BlockMMA
// and constexpr size calculations
// Instantiate the appropriate BlockMMA and Loader
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
@ -231,8 +332,133 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
const uint lidy = lid / SIMD_SIZE;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_int = 32 / bits;
constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
constexpr int groups_per_simd = BK / (WM * WN);
constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
// Instantiate the appropriate BlockMMA and Loader
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BK * groups_per_block];
threadgroup T biases_block[BK * groups_per_block];
threadgroup T Xs[BM * BK];
threadgroup T Ws[BK * BN];
// Set the block
const int N_w = N / el_per_int;
const int N_g = N / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col / el_per_int;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
} else {
loader_x.load_unsafe();
}
// Load the scale and bias
if (simd_lid == 0) {
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
const device T *scales_local = scales + lidy * groups_per_simd * N_g;
const device T *biases_local = biases + lidy * groups_per_simd * N_g;
#pragma clang loop unroll(full)
for (int gs=0; gs<groups_per_simd; gs++) {
#pragma clang loop unroll(full)
for (int gc=0; gc<groups_per_block; gc++) {
scales_block_local[gc] = scales_local[gc];
biases_block_local[gc] = biases_local[gc];
}
scales_block_local += groups_per_block;
scales_local += N_g;
biases_block_local += groups_per_block;
biases_local += N_g;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(Xs, Ws);
// Prepare for next iteration
loader_x.next();
w += BK * N_w;
scales += BK * N_g;
biases += BK * N_g;
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
@ -258,6 +484,33 @@ instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qvm_types(group_size, bits) \
instantiate_qvm(float32, float, group_size, bits) \
instantiate_qvm(float16, half, group_size, bits) \
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
#define instantiate_qmm_t(name, itype, group_size, bits) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
@ -285,3 +538,31 @@ instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float32, float, group_size, bits) \
instantiate_qmm_n(float16, half, group_size, bits) \
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)

View File

@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
@ -23,97 +22,147 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& biases_pre = inputs[3];
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
return arr_copy;
}
};
auto [x_transposed, x_cols, x] = check_transpose(x_pre);
auto [w_transposed, w_cols, w] = check_transpose(w_pre);
auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);
if (!w_transposed) {
throw std::runtime_error("The quantized weight should be transposed.");
}
if (x_transposed || scales_transposed || biases_transposed) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
int D = x.shape(-1);
int B = x.size() / D;
int O = out.shape(-1);
if (transpose_) {
// Route to the qmv kernel
if (B < 6) {
std::ostringstream kname;
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
// Route to the qmv kernel
if (B == 1) {
std::ostringstream kname;
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
<< "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
int O = w.size() / w_cols;
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
// Route to the qmm_t kernel
else {
std::ostringstream kname;
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Route to the qmm kernel
else {
std::ostringstream kname;
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
<< "_gs_" << group_size_ << "_b_" << bits_;
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
int O = w.size() / w_cols;
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the qvm kernel
if (B < 4) {
std::ostringstream kname;
kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_n kernel
else {
std::ostringstream kname;
kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 64;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
d.get_command_buffer(s.index)->addCompletedHandler(

View File

@ -2618,10 +2618,11 @@ array quantized_matmul(
const array& w,
const array& scales,
const array& biases,
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
StreamOrDevice s /* = {} */) {
auto x = in_x;
array x = in_x;
if (w.dtype() != uint32) {
std::ostringstream msg;
@ -2646,39 +2647,52 @@ array quantized_matmul(
x = reshape(x, {-1, x_inner_dims}, s);
}
int w_inner_dims = w.shape(0) * (32 / bits);
if (w_inner_dims != x_inner_dims) {
if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[quantized_matmul] Last dimension of first input with "
<< "shape (..., " << x_inner_dims
<< ") does not match the expanded first "
<< "dimension of the quantized matrix " << w_inner_dims
<< ", computed from shape " << w.shape()
msg << "[quantized_matmul] Scales and biases should have the same 2D shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) {
std::ostringstream msg;
msg << "[quantized_matmul] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
int n_groups = x_inner_dims / group_size;
if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
// Calculate the expanded w's dims
int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0);
int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits;
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
msg << "[quantized_matmul] Scales and biases provided do not match the "
<< "quantization arguments (group_size=" << group_size
<< ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", "
<< x_inner_dims / group_size
<< "), but got scales.shape=" << scales.shape()
<< " and biases.shape=" << biases.shape();
msg << "[quantized_matmul] Last dimension of first input with "
<< "shape (..., " << x_inner_dims << ") does not match "
<< "the expanded quantized matrix (" << w_inner_dims << ", "
<< w_outer_dims << ") computed from shape " << w.shape()
<< " with group_size=" << group_size << ", bits=" << bits
<< " and transpose=" << std::boolalpha << transpose;
throw std::invalid_argument(msg.str());
}
auto dtype = result_type({x, scales, biases});
auto out = array(
{x.shape(0), w.shape(1)},
x.dtype(),
std::make_unique<QuantizedMatmul>(to_stream(s), group_size, bits),
{x, w, scales, biases});
{x.shape(0), w_outer_dims},
dtype,
std::make_unique<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
{astype(x, dtype, s),
w,
astype(scales, dtype, s),
astype(biases, dtype, s)});
// If needed reshape x to the original batch shape
if (original_shape.size() != 1) {
original_shape.push_back(w.shape(1));
original_shape.push_back(w_outer_dims);
out = reshape(out, original_shape, s);
}

View File

@ -1041,6 +1041,7 @@ array quantized_matmul(
const array& w,
const array& scales,
const array& biases,
bool transpose = true,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});

View File

@ -1706,14 +1706,37 @@ std::vector<array> QuantizedMatmul::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
throw std::runtime_error("QuantizedMatmul::vjp NYI");
std::vector<array> vjps;
// We rely on the fact that w is always 2D so transpose is simple
for (auto arg : argnums) {
// gradient wrt to x
if (arg == 0) {
vjps.push_back(quantized_matmul(
cotan,
primals[1],
primals[2],
primals[3],
!transpose_,
group_size_,
bits_,
stream()));
}
// gradient wrt to w_q, scales or biases
else {
throw std::runtime_error(
"QuantizedMatmul::vjp no gradient wrt the quantized matrix yet.");
}
}
return vjps;
}
array QuantizedMatmul::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("QuantizedMatmul::vjp NYI");
throw std::runtime_error("QuantizedMatmul::jvp NYI");
}
bool QuantizedMatmul::is_equivalent(const Primitive& other) const {

View File

@ -1112,8 +1112,15 @@ class Power : public Primitive {
class QuantizedMatmul : public Primitive {
public:
explicit QuantizedMatmul(Stream stream, int group_size, int bits)
: Primitive(stream), group_size_(group_size), bits_(bits){};
explicit QuantizedMatmul(
Stream stream,
int group_size,
int bits,
bool transpose)
: Primitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1129,6 +1136,7 @@ class QuantizedMatmul : public Primitive {
private:
int group_size_;
int bits_;
bool transpose_;
void eval(const std::vector<array>& inputs, array& out);
};

View File

@ -81,9 +81,10 @@ class QuantizedLinear(Module):
def __call__(self, x):
x = mx.quantized_matmul(
x,
self.weight.T,
self.weight,
scales=self.scales,
biases=self.biases,
transpose=True,
group_size=self.group_size,
bits=self.bits,
)

View File

@ -3072,12 +3072,13 @@ void init_ops(py::module_& m) {
py::pos_only(),
"scales"_a,
"biases"_a,
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@ -3089,10 +3090,13 @@ void init_ops(py::module_& m) {
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. (default: ``True``)
group_size (int, optional): The size of the group in ``w`` that
shares a scale and bias. (default: 64)
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: 4)
``w``. (default: ``4``)
Returns:
result (array): The result of the multiplication of ``x`` with ``w``.
@ -3146,9 +3150,9 @@ void init_ops(py::module_& m) {
Args:
w (array): Matrix to be quantized
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 64)
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: 4)
``w`` in the returned quantized matrix. (default: ``4``)
Returns:
(tuple): A tuple containing
@ -3187,9 +3191,9 @@ void init_ops(py::module_& m) {
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: 64)
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: 4)
``w``. (default: ``4``)
Returns:
result (array): The dequantized version of ``w``

View File

@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.
import unittest
from itertools import product
import mlx.core as mx
import mlx_tests
@ -19,62 +20,116 @@ class TestQuantized(mlx_tests.MLXTestCase):
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for group_size in [128, 64]:
for bits in [2, 4, 8]:
for M in [8, 32, 33, 64]:
for N in [512, 1024]:
for K in [512, 1024]:
with self.subTest(
shape=(M, N, K), group_size=group_size, bits=bits
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(
w_q, scales, biases, group_size, bits
)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, group_size, bits
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
tests = product(
[128, 64], # group_size
[2, 4, 8], # bits
[8, 32, 33, 64], # M
[512, 1024], # N
[512, 1024], # K
[True, False], # transposed
)
for group_size, bits, M, N, K, transposed in tests:
with self.subTest(
shape=(M, N, K),
group_size=group_size,
bits=bits,
transposed=transposed,
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, transposed, group_size, bits
)
y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qmm_shapes(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
group_size = 64
bits = 4
w = mx.random.normal(shape=(32, 128), key=k2)
w = mx.random.normal(shape=(32, 256), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
for s in [(3, 128), (2, 1, 7, 128)]:
x = mx.random.normal(shape=(3, 128), key=k1)
y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
for s in [(3, 256), (2, 1, 7, 256)]:
x = mx.random.normal(shape=s, key=k1)
y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
w = mx.random.normal(shape=(256, 256), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
for s in [(3, 256), (2, 1, 7, 256)]:
x = mx.random.normal(shape=s, key=k1)
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qmv(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for group_size in [128, 64]:
for bits in [2, 4, 8]:
for M in [512, 1024]:
for N in [512, 1024]:
with self.subTest(
shape=(M, N), group_size=group_size, bits=bits
):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, group_size, bits
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
tests = product(
[128, 64], # group_size
[2, 4, 8], # bits
[512, 1024], # M
[512, 1024], # N
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, True, group_size, bits
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64], # group_size
[2, 4, 8], # bits
[512, 1024], # M
[512, 1024], # N
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(N, M), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))
w_q, scales, biases = mx.quantize(w)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q.T, scales, biases)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q.T, scales.T, biases)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q, scales, biases, False)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q, scales.T, biases.T)
y = mx.quantized_matmul(x, w_q, scales, biases, True)
mx.eval(y)
if __name__ == "__main__":