An initial quantized matmul implementation (#205)

* Add quantized matvec
* Add quantized matrix matrix with 2nd matrix transposed
* Add quantized matmul tests
* Add a slow cpu quantized matmul
* Add a slightly faster vectorized cpu version
This commit is contained in:
Angelos Katharopoulos 2023-12-18 23:18:57 -08:00 committed by GitHub
parent e6872a4149
commit dfa9f4bc58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1029 additions and 10 deletions

View File

@ -23,6 +23,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return mx.float32
else:
dt = getattr(mx, x)
if not isinstance(dt, mx.Dtype):
raise ValueError(f"{x} is not an mlx dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@ -49,6 +59,15 @@ def matmul(x, y):
mx.eval(ys)
def quant_matmul(x, w, s, b):
groups = x.shape[-1] // s.shape[-1]
width = 32 // (x.shape[-1] // w.shape[0])
ys = []
for i in range(10):
ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
mx.eval(ys)
def conv1d(x, y):
ys = []
for i in range(10):
@ -296,9 +315,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument(
"--dtype", choices=["float32", "float16", "bfloat16"], default="float32"
)
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@ -315,11 +332,15 @@ if __name__ == "__main__":
mx.set_default_device(mx.cpu)
else:
mx.set_default_device(mx.gpu)
dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[
args.dtype
]
types = args.dtype
if not types:
types = [mx.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(mx.random.normal(size).astype(dtype))
for i, t in enumerate(args.transpose):
if t is None:
@ -335,6 +356,9 @@ if __name__ == "__main__":
elif args.benchmark == "matmul":
print(bench(matmul, *xs))
elif args.benchmark == "quant_matmul":
print(bench(quant_matmul, *xs))
elif args.benchmark == "linear":
print(bench(linear, *xs))

View File

@ -22,6 +22,16 @@ def none_or_list(x):
return [int(xi) for xi in x.split(",")]
def dtype_from_str(x):
if x == "":
return torch.float32
else:
dt = getattr(torch, x)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{x} is not a torch dtype")
return dt
def bench(f, *args):
for i in range(10):
f(*args)
@ -312,7 +322,7 @@ if __name__ == "__main__":
parser.add_argument(
"--fused", action="store_true", help="Use fused functions where possible"
)
parser.add_argument("--dtype", choices=["float32", "float16"], default="float32")
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
args = parser.parse_args()
@ -327,9 +337,15 @@ if __name__ == "__main__":
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype]
types = args.dtype
if not types:
types = [torch.float32]
if len(types) < len(args.size):
types = types + [types[0]] * (len(args.size) - len(types))
xs = []
for size in args.size:
for size, dtype in zip(args.size, types):
xs.append(torch.randn(*size).to(device).to(dtype))
for i, t in enumerate(args.transpose):
if t is None:

View File

@ -4,6 +4,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
)

View File

@ -0,0 +1,107 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <simd/vector.h>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void _qmm_t_4_64(
float* result,
const float* x,
const uint32_t* w,
const float* scales,
const float* biases,
int M,
int N,
int K) {
constexpr int width = 4;
constexpr int groups = 64;
constexpr int bitmask = (1 << width) - 1;
constexpr int pack_factor = 32 / width;
constexpr int packs_in_group = groups / pack_factor;
const int Kg = K / groups;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += groups) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= width;
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
x += K;
}
}
} // namespace
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
if (w.strides()[0] != 1) {
throw std::runtime_error("The quantized weight should be transposed");
}
if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
!biases.flags().row_contiguous) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K);
} else {
eval(inputs, out);
}
}
} // namespace mlx::core

View File

@ -8,6 +8,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

View File

@ -62,6 +62,7 @@ DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Reshape)

View File

@ -0,0 +1,183 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, int width, int groups>
void _qmm_t(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K) {
constexpr int bitmask = (1 << width) - 1;
constexpr int pack_factor = 32 / width;
constexpr int packs_in_group = groups / pack_factor;
const int Kg = K / groups;
const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const T* scales_local = scales;
const T* biases_local = biases;
for (int n = 0; n < N; n++) {
const T* x_local = x;
T sum = 0;
for (int k = 0; k < K; k += groups) {
T scale = *scales_local++;
T bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw++) {
uint32_t wi = *w_local++;
#pragma clang loop unroll(full)
for (int p = 0; p < pack_factor; p++) {
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
wi >>= width;
}
}
}
*result = sum;
result++;
}
x += K;
}
}
template <typename T>
void _qmm_t_dispatch_typed(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int width,
int groups) {
switch (width) {
case 2: {
switch (groups) {
case 64:
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
case 128:
return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
}
}
case 4: {
switch (groups) {
case 64:
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
case 128:
return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
}
}
case 8: {
switch (groups) {
case 64:
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
case 128:
return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
}
}
}
std::ostringstream msg;
msg << "Quantization type not supported. Provided bit width=" << width
<< " and groups=" << groups << ". The supported options are width in "
<< "{2, 4, 8} and groups in {64, 128}.";
throw std::invalid_argument(msg.str());
}
void _qmm_t_dispatch(
array out,
const array& x,
const array& w,
const array& scales,
const array& biases,
int width,
int groups) {
int K = x.shape(-1);
int M = x.size() / K;
int N = w.shape(1);
switch (x.dtype()) {
case float32:
_qmm_t_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
width,
groups);
break;
case float16:
_qmm_t_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
scales.data<float16_t>(),
biases.data<float16_t>(),
M,
N,
K,
width,
groups);
break;
case bfloat16:
_qmm_t_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
scales.data<bfloat16_t>(),
biases.data<bfloat16_t>(),
M,
N,
K,
width,
groups);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
auto& x = inputs[0];
auto& w = inputs[1];
auto& scales = inputs[2];
auto& biases = inputs[3];
if (w.strides()[0] != 1) {
throw std::runtime_error("The quantized weight should be transposed");
}
if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
!biases.flags().row_contiguous) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
}
} // namespace mlx::core

View File

@ -10,6 +10,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp

View File

@ -18,6 +18,7 @@ set(
"copy"
"gemm"
"gemv"
"quantized"
"random"
"reduce"
"scan"

View File

@ -0,0 +1,287 @@
// Copyright © 2023 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
template <typename T, const int BM, const int BN, const int groups, const int width>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
constexpr int bitmask = (1 << width) - 1;
constexpr int el_per_thread = 32 / width;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / groups;
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[colgroup];
thread uint32_t w_local;
thread T result = 0;
thread T scale = 1;
thread T bias = 0;
thread T x_thread[el_per_thread];
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
const int in_vec_size_g = in_vec_size / groups;
int out_row = tid.y * BM + simd_gid;
w += out_row * in_vec_size_w;
scales += out_row * in_vec_size_g;
biases += out_row * in_vec_size_g;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid < simdgroups_fetching_vec) {
x_block[lid] = x[lid + i];
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load in_vec, scale, bias to registers
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_thread[j] = x_block[simd_lid*el_per_thread + j];
}
scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
// Load the matrix elements
w_local = w[i / el_per_thread + simd_lid];
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= width;
}
}
// Accumulate in the simdgroup
result = simd_sum(result);
// Store the result
if (simd_lid == 0) {
y[out_row] = result;
}
}
template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
const uint lidy = lid / SIMD_SIZE;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int bitmask = (1 << width) - 1;
constexpr int el_per_int = 32 / width;
constexpr int ints_per_block = BK / el_per_int;
constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
constexpr int groups_per_simd = BN / (WM * WN);
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
// Using the kernel just as a type to instantiate the appropriate BlockMMA
// and constexpr size calculations
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BN * groups_per_block];
threadgroup T biases_block[BN * groups_per_block];
threadgroup T Xs[BM * BK];
threadgroup T Ws[BN * BK];
// Set the block
const int K_w = K / el_per_int;
const int K_g = K / groups;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
} else {
loader_x.load_unsafe();
}
// Load the scale and bias
if (simd_lid == 0) {
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
#pragma clang loop unroll(full)
for (int gs=0; gs<groups_per_simd; gs++) {
#pragma clang loop unroll(full)
for (int gc=0; gc<groups_per_block; gc++) {
scales_block_local[gc] = scales_local[gc];
biases_block_local[gc] = biases_local[gc];
}
scales_block_local += groups_per_block;
scales_local += K_g;
biases_block_local += groups_per_block;
biases_local += K_g;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= width;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(Xs, Ws);
// Prepare for next iteration
loader_x.next();
w += ints_per_block;
// scales and biases cannot be advanced because they would have to be
// advanced every other iteration or sth.
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
#define instantiate_qmv(name, itype, groups, width) \
template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
[[kernel]] void qmv<itype, 32, 32, groups, width>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmv_types(groups, width) \
instantiate_qmv(float32, float, groups, width) \
instantiate_qmv(float16, half, groups, width) \
instantiate_qmv(bfloat16, bfloat16_t, groups, width)
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
#define instantiate_qmm_t(name, itype, groups, width) \
template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_t_types(groups, width) \
instantiate_qmm_t(float32, float, groups, width) \
instantiate_qmm_t(float16, half, groups, width) \
instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)

View File

@ -0,0 +1,123 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [x_transposed, x_cols, x] = check_transpose(x_pre);
auto [w_transposed, w_cols, w] = check_transpose(w_pre);
auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);
if (!w_transposed) {
throw std::runtime_error("The quantized weight should be transposed.");
}
if (x_transposed || scales_transposed || biases_transposed) {
throw std::runtime_error("x, scales and biases should be row contiguous.");
}
int D = x.shape(-1);
int B = x.size() / D;
// Route to the qmv kernel
if (B == 1) {
std::ostringstream kname;
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int O = w.size() / w_cols;
int bo = 32;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm kernel
else {
std::ostringstream kname;
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
<< "_groups_" << groups_ << "_width_" << width_;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int O = w.size() / w_cols;
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, scales, 2);
set_array_buffer(compute_encoder, biases, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@ -58,6 +58,7 @@ NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU(QuantizedMatmul)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Reshape)

View File

@ -2564,4 +2564,75 @@ array conv2d(
{in, wt});
}
array quantized_matmul(
const array& in_x,
const array& w,
const array& scales,
const array& biases,
int groups /* = 128 */,
int width /* = 4 */,
StreamOrDevice s /* = {} */) {
auto x = in_x;
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[quantized_matmul] The weight matrix should be uint32 "
<< "but received" << w.dtype();
throw std::invalid_argument(msg.str());
}
if (w.ndim() != 2) {
std::ostringstream msg;
msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
<< "received w with shape " << w.shape();
throw std::invalid_argument(msg.str());
}
// Keep x's batch dimensions to reshape it back after the matmul
auto original_shape = x.shape();
int x_inner_dims = original_shape.back();
original_shape.pop_back();
// Reshape x into a matrix if it isn't already one
if (x.ndim() != 2) {
x = reshape(x, {-1, x_inner_dims}, s);
}
int w_inner_dims = w.shape(0) * (32 / width);
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
msg << "[quantized_matmul] Last dimension of first input with "
<< "shape (..., " << x_inner_dims
<< ") does not match the expanded first "
<< "dimension of the quantized matrix " << w_inner_dims
<< ", computed from shape " << w.shape() << " with groups=" << groups
<< " and width=" << width;
throw std::invalid_argument(msg.str());
}
int n_groups = x_inner_dims / groups;
if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
std::ostringstream msg;
msg << "[quantized_matmul] Scales and biases provided do not match the "
<< "quantization arguments (groups=" << groups << ", width=" << width
<< "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups
<< "), but got scales.shape=" << scales.shape()
<< " and biases.shape=" << biases.shape();
throw std::invalid_argument(msg.str());
}
auto out = array(
{x.shape(0), w.shape(1)},
x.dtype(),
std::make_unique<QuantizedMatmul>(to_stream(s), groups, width),
{x, w, scales, biases});
// If needed reshape x to the original batch shape
if (original_shape.size() != 1) {
original_shape.push_back(w.shape(1));
out = reshape(out, original_shape, s);
}
return out;
}
} // namespace mlx::core

View File

@ -1028,4 +1028,14 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
/** Load array from file in .npy format */
array load(const std::string& file, StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
array quantized_matmul(
const array& x,
const array& w,
const array& scales,
const array& biases,
int groups = 128,
int width = 4,
StreamOrDevice s = {});
} // namespace mlx::core

View File

@ -1696,6 +1696,31 @@ std::pair<array, int> Power::vmap(
return {power(a, b, stream()), to_ax};
}
std::pair<array, int> QuantizedMatmul::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("QuantizedMatmul::vmap NYI");
}
std::vector<array> QuantizedMatmul::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
throw std::runtime_error("QuantizedMatmul::vjp NYI");
}
array QuantizedMatmul::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("QuantizedMatmul::vjp NYI");
}
bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
return groups_ == qm_other.groups_ && width_ == qm_other.width_;
}
std::pair<array, int> RandomBits::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {

View File

@ -1110,6 +1110,29 @@ class Power : public Primitive {
void eval(const std::vector<array>& inputs, array& out);
};
class QuantizedMatmul : public Primitive {
public:
explicit QuantizedMatmul(Stream stream, int groups, int width)
: Primitive(stream), groups_(groups), width_(width){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_GRADS()
DEFINE_PRINT(QuantizedMatmul)
bool is_equivalent(const Primitive& other) const override;
private:
int groups_;
int width_;
void eval(const std::vector<array>& inputs, array& out);
};
class RandomBits : public Primitive {
public:
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)

View File

@ -2977,4 +2977,36 @@ void init_ops(py::module_& m) {
Returns:
result (array): An array of the same type as ``a`` rounded to the given number of decimals.
)pbdoc");
m.def(
"quantized_matmul",
&quantized_matmul,
"x"_a,
"w"_a,
py::pos_only(),
"scales"_a,
"biases"_a,
"groups"_a = 128,
"width"_a = 4,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``groups`` of
elements. Each element in ``w`` takes ``width`` bits and is packed in an
unsigned 32 bit integer.
Args:
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``groups`` elements of ``w``
biases (array): The biases to use per ``groups`` elements of ``w``
groups (int): The size of the group in ``w`` that shares a scale and
bias. (default: 128)
width (int): The bitwidth of the elements in ``w``. (default: 4)
Returns:
result (array): The result of the multiplication of ``x`` with ``w``.
)pbdoc");
}

View File

@ -0,0 +1,112 @@
# Copyright © 2023 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
def select_bits(w, width, start):
shift_left = 32 - (start + width)
shift_right = shift_left + start
return (w * (2**shift_left)) // (2**shift_right)
def dequantize(w, scales, biases, width):
w_full = mx.concatenate(
[select_bits(w, width, i)[..., None] for i in range(0, 32, width)], axis=-1
)
w_full = w_full.reshape(len(w), scales.shape[-1], -1)
w_full = scales[..., None] * w_full + biases[..., None]
w_full = w_full.reshape(len(w), -1)
return w_full
def quantize(w, width, groups):
w = w.reshape(len(w), -1, groups)
w_max = w.max(-1, keepdims=True)
w_min = w.min(-1, keepdims=True)
delta = (w_max - w_min) / (2**width - 1)
w_int = mx.round((w - w_min) / delta).astype(mx.uint32)
scales = delta.squeeze(-1)
biases = w_min.squeeze(-1)
shifts = mx.array([2**i for i in range(0, 32, width)], dtype=mx.uint32)
w_int = w_int.reshape(len(w), -1, 32 // width)
w_int = w_int * shifts[None, None]
packed_w = w_int.sum(-1)
return packed_w, scales, biases
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 128))
w_q, scales, biases = quantize(w, 4, 64)
w_hat = dequantize(w_q, scales, biases, 4)
w_hat2 = dequantize(*quantize(w_hat, 4, 64), 4)
self.assertLess((w_hat - w_hat2).abs().max(), 1e-6)
def test_qmm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for groups in [128, 64]:
for width in [2, 4, 8]:
for M in [8, 32, 33, 64]:
for N in [512, 1024]:
for K in [512, 1024]:
with self.subTest(
shape=(M, N, K), groups=groups, width=width
):
x = mx.random.normal(shape=(M, K), key=k1)
w = mx.random.normal(shape=(N, K), key=k2)
w_q, scales, biases = quantize(w, width, groups)
w_hat = dequantize(w_q, scales, biases, width)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 0.1)
def test_qmm_shapes(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
groups = 64
width = 4
w = mx.random.normal(shape=(32, 128), key=k2)
w_q, scales, biases = quantize(w, width, groups)
w_hat = dequantize(w_q, scales, biases, width)
for s in [(3, 128), (2, 1, 7, 128)]:
x = mx.random.normal(shape=(3, 128), key=k1)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 0.1)
def test_qmv(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
for groups in [128, 64]:
for width in [2, 4, 8]:
for M in [512, 1024]:
for N in [512, 1024]:
# with self.subTest(shape=(M, N), groups=groups, width=width):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = quantize(w, width, groups)
w_hat = dequantize(w_q, scales, biases, width)
y_q = mx.quantized_matmul(
x, w_q.T, scales, biases, width=width, groups=groups
)
y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 0.1)
if __name__ == "__main__":
unittest.main()