Support for quantized matmul with w and w^T (#349)

* Add the metal qvm implementation
* Add qmm_n
* Add gradient wrt the input for quantized_matmul
Angelos Katharopoulos 2024-01-03 14:22:36 -08:00 committed by GitHub
parent d7ac050f4b
commit e7f5059fe4
12 changed files with 718 additions and 193 deletions
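
Not part of the diff: a minimal usage sketch of the new ``transpose`` argument added by this commit, based on the Python bindings below (array shapes are illustrative).

import mlx.core as mx

x = mx.random.normal(shape=(8, 512))

# x @ w.T with an (N, K) weight: the previously supported layout
w_t = mx.random.normal(shape=(256, 512))
wq_t, sc_t, bs_t = mx.quantize(w_t, group_size=64, bits=4)
y_t = mx.quantized_matmul(x, wq_t, sc_t, bs_t, transpose=True)

# x @ w with a (K, N) weight: the layout enabled by this commit
w_n = mx.random.normal(shape=(512, 256))
wq_n, sc_n, bs_n = mx.quantize(w_n, group_size=64, bits=4)
y_n = mx.quantized_matmul(x, wq_n, sc_n, bs_n, transpose=False)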

View File

@@ -4,6 +4,7 @@ import argparse
 import math
 import os
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -59,15 +60,23 @@ def matmul(x, y):
     mx.eval(ys)


-def quant_matmul(x, w, s, b):
-    groups = x.shape[-1] // s.shape[-1]
-    width = 32 // (x.shape[-1] // w.shape[0])
+def _quant_matmul(x, w, s, b, group_size, bits):
     ys = []
     for i in range(10):
-        ys.append(mx.quantized_matmul(x, w, s, b, groups=groups, width=width))
+        ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
     mx.eval(ys)


+quant_matmul = {
+    "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
+    "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
+    "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
+    "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
+    "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
+    "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
+}
+
+
 def conv1d(x, y):
     ys = []
     for i in range(10):
@@ -356,8 +365,8 @@ if __name__ == "__main__":
     elif args.benchmark == "matmul":
         print(bench(matmul, *xs))

-    elif args.benchmark == "quant_matmul":
-        print(bench(quant_matmul, *xs))
+    elif args.benchmark.startswith("quant_matmul"):
+        print(bench(quant_matmul[args.benchmark], *xs))

     elif args.benchmark == "linear":
         print(bench(linear, *xs))

View File

@@ -76,20 +76,16 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& scales = inputs[2];
   auto& biases = inputs[3];

-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
-  if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
+  bool condition =
+      (transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
+       scales.flags().row_contiguous && biases.flags().row_contiguous &&
+       x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
+
+  if (condition) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
     int M = x.size() / K;
-    int N = w.shape(1);
+    int N = out.shape(-1);

     _qmm_t_4_64(
         out.data<float>(),
         x.data<float>(),

View File

@@ -1,13 +1,62 @@
 // Copyright © 2023 Apple Inc.

 #include <cassert>
-#include <iostream>

+#include "mlx/backend/metal/copy.h"
 #include "mlx/primitives.h"

 namespace mlx::core {

 namespace {

+template <typename T, int bits, int group_size>
+void _qmm(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K) {
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Ng = N / group_size;
+  const int Nw = N / pack_factor;
+
+  for (int m = 0; m < M; m++) {
+    const uint32_t* w_local = w;
+    const T* scales_local = scales;
+    const T* biases_local = biases;
+
+    std::fill(result, result + N, 0);
+
+    for (int k = 0; k < K; k++) {
+      T* result_local = result;
+      T xi = *x++;
+      for (int n = 0; n < N; n += group_size) {
+        T scale = *scales_local++;
+        T bias = *biases_local++;
+        for (int ng = 0; ng < packs_in_group; ng++) {
+          uint32_t wi = *w_local++;
+
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            wi >>= bits;
+          }
+        }
+      }
+    }
+
+    result += N;
+  }
+}
+
 template <typename T, int bits, int group_size>
 void _qmm_t(
     T* result,
@@ -55,7 +104,7 @@ void _qmm_t(
 }

 template <typename T>
-void _qmm_t_dispatch_typed(
+void _qmm_dispatch_typed(
     T* result,
     const T* x,
     const uint32_t* w,
@@ -65,30 +114,55 @@ void _qmm_t_dispatch_typed(
     int N,
     int K,
     int group_size,
-    int bits) {
+    int bits,
+    bool transposed_w) {
   switch (bits) {
     case 2: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
     case 4: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
     case 8: {
       switch (group_size) {
         case 64:
-          return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
+          }
         case 128:
-          return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          if (transposed_w) {
+            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          } else {
+            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
+          }
       }
     }
   }
@@ -100,21 +174,22 @@ void _qmm_t_dispatch_typed(
   throw std::invalid_argument(msg.str());
 }

-void _qmm_t_dispatch(
+void _qmm_dispatch(
     array out,
     const array& x,
     const array& w,
     const array& scales,
     const array& biases,
     int bits,
-    int group_size) {
+    int group_size,
+    bool transposed_w) {
   int K = x.shape(-1);
   int M = x.size() / K;
-  int N = w.shape(1);
+  int N = out.shape(-1);

   switch (x.dtype()) {
     case float32:
-      _qmm_t_dispatch_typed<float>(
+      _qmm_dispatch_typed<float>(
           out.data<float>(),
           x.data<float>(),
           w.data<uint32_t>(),
@@ -124,10 +199,11 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     case float16:
-      _qmm_t_dispatch_typed<float16_t>(
+      _qmm_dispatch_typed<float16_t>(
           out.data<float16_t>(),
           x.data<float16_t>(),
           w.data<uint32_t>(),
@@ -137,10 +213,11 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     case bfloat16:
-      _qmm_t_dispatch_typed<bfloat16_t>(
+      _qmm_dispatch_typed<bfloat16_t>(
           out.data<bfloat16_t>(),
           x.data<bfloat16_t>(),
           w.data<uint32_t>(),
@@ -150,7 +227,8 @@ void _qmm_t_dispatch(
           N,
           K,
           bits,
-          group_size);
+          group_size,
+          transposed_w);
       break;
     default:
       throw std::invalid_argument(
@@ -163,22 +241,28 @@ void _qmm_t_dispatch(
 void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 4);

-  auto& x = inputs[0];
-  auto& w = inputs[1];
-  auto& scales = inputs[2];
-  auto& biases = inputs[3];
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];

-  if (w.strides()[0] != 1) {
-    throw std::runtime_error("The quantized weight should be transposed");
-  }
+  auto ensure_row_contiguous = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };

-  if (!x.flags().row_contiguous || !scales.flags().row_contiguous ||
-      !biases.flags().row_contiguous) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
+  auto x = ensure_row_contiguous(x_pre);
+  auto w = ensure_row_contiguous(w_pre);
+  auto scales = ensure_row_contiguous(scales_pre);
+  auto biases = ensure_row_contiguous(biases_pre);

   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
+  _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }

 } // namespace mlx::core

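For reference (not part of the commit), the new non-transposed CPU path ``_qmm`` above amounts to unpacking the ``bits``-bit values low bits first, applying one scale and bias per ``group_size`` outputs, and doing a plain matmul. A NumPy sketch, with variable names of my own choosing:

import numpy as np

def qmm_reference(x, w_packed, scales, biases, group_size=64, bits=4):
    # w_packed: (K, N // pack_factor) uint32, packed low bits first
    pack_factor = 32 // bits
    mask = (1 << bits) - 1
    K, Nw = w_packed.shape
    N = Nw * pack_factor
    shifts = np.arange(pack_factor, dtype=np.uint32) * bits
    vals = ((w_packed[:, :, None] >> shifts) & mask).reshape(K, N)
    # scales / biases: (K, N // group_size), broadcast over each group along N
    w_hat = vals.astype(x.dtype) * np.repeat(scales, group_size, axis=1) \
        + np.repeat(biases, group_size, axis=1)
    return x @ w_hat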
View File

@@ -104,6 +104,108 @@ template <typename T, const int BM, const int BN, const int group_size, const int bits>
 }
template <typename T, const int BM, const int BN, const int group_size, const int bits>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
static_assert(BN == BM, "qvm expects a block size of 32x32");
(void)lid;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_int = 32 / bits;
constexpr int colgroup = BN * el_per_int;
constexpr int groups_per_block = colgroup / group_size;
threadgroup T scales_block[BM * groups_per_block];
threadgroup T biases_block[BM * groups_per_block];
threadgroup T x_block[BM];
thread uint32_t w_local;
thread T result[el_per_int] = {0};
thread T scale = 1;
thread T bias = 0;
thread T x_local = 0;
// Adjust positions
const int out_vec_size_w = out_vec_size / el_per_int;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = (tid.y * BN + simd_gid) * el_per_int;
w += out_col / el_per_int;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size + out_col;
if (out_col >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = x[simd_lid + i];
}
// Load the scales and biases to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load in_vec, scale, bias to registers
x_local = x_block[simd_lid];
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = w[(i + simd_lid) * out_vec_size_w];
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
w_local >>= bits;
}
}
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k=0; k<el_per_int; k++) {
y[k] = result[k];
}
}
}
 template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
@@ -133,8 +235,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
   constexpr int groups_per_simd = BN / (WM * WN);
   constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

-  // Using the kernel just as a type to instantiate the appropriate BlockMMA
-  // and constexpr size calculations
+  // Instantiate the appropriate BlockMMA and Loader
   using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, true>;
   using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
@@ -231,8 +332,133 @@ template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
 }
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
const uint lidy = lid / SIMD_SIZE;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_int = 32 / bits;
constexpr int groups_per_block = (BN / group_size > 0) ? (BN / group_size) : 1;
constexpr int groups_per_simd = BK / (WM * WN);
constexpr int w_els_per_thread = (BK * BN / el_per_int) / (SIMD_SIZE * WM * WN);
// Instantiate the appropriate BlockMMA and Loader
using mma_t = BlockMMA<T, BM, BN, BK, WM, WN, false, false>;
using loader_x_t = BlockLoader<T, BM, BK, BK, 4, WM * WN * SIMD_SIZE, false, true, 0>;
threadgroup T scales_block[BK * groups_per_block];
threadgroup T biases_block[BK * groups_per_block];
threadgroup T Xs[BM * BK];
threadgroup T Ws[BK * BN];
// Set the block
const int N_w = N / el_per_int;
const int N_g = N / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
w += y_col / el_per_int;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
for (int k=0; k<K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the x tile
if (num_els < BM) {
loader_x.load_safe(short2(BK, num_els));
} else {
loader_x.load_unsafe();
}
// Load the scale and bias
if (simd_lid == 0) {
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
const device T *scales_local = scales + lidy * groups_per_simd * N_g;
const device T *biases_local = biases + lidy * groups_per_simd * N_g;
#pragma clang loop unroll(full)
for (int gs=0; gs<groups_per_simd; gs++) {
#pragma clang loop unroll(full)
for (int gc=0; gc<groups_per_block; gc++) {
scales_block_local[gc] = scales_local[gc];
biases_block_local[gc] = biases_local[gc];
}
scales_block_local += groups_per_block;
scales_local += N_g;
biases_block_local += groups_per_block;
biases_local += N_g;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(Xs, Ws);
// Prepare for next iteration
loader_x.next();
w += BK * N_w;
scales += BK * N_g;
biases += BK * N_g;
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
} else {
mma_op.store_result(y, N);
}
}
 #define instantiate_qmv(name, itype, group_size, bits) \
-  template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
+  template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
   [[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
       const device uint32_t* w [[buffer(0)]], \
       const device itype* scales [[buffer(1)]], \
@@ -258,6 +484,33 @@ instantiate_qmv_types( 64, 2)
 instantiate_qmv_types( 64, 4)
 instantiate_qmv_types( 64, 8)
#define instantiate_qvm(name, itype, group_size, bits) \
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qvm_types(group_size, bits) \
instantiate_qvm(float32, float, group_size, bits) \
instantiate_qvm(float16, half, group_size, bits) \
instantiate_qvm(bfloat16, bfloat16_t, group_size, bits)
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
 #define instantiate_qmm_t(name, itype, group_size, bits) \
   template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
   [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
@@ -285,3 +538,31 @@ instantiate_qmm_t_types(128, 8)
 instantiate_qmm_t_types( 64, 2)
 instantiate_qmm_t_types( 64, 4)
 instantiate_qmm_t_types( 64, 8)
#define instantiate_qmm_n(name, itype, group_size, bits) \
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_n<itype, 32, 32, 64, group_size, bits>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
const device itype* biases [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& M [[buffer(5)]], \
const constant int& N [[buffer(6)]], \
const constant int& K [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float32, float, group_size, bits) \
instantiate_qmm_n(float16, half, group_size, bits) \
instantiate_qmm_n(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)

View File

@@ -1,7 +1,6 @@
 // Copyright © 2023 Apple Inc.

 #include <cassert>
-#include <iostream>

 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -23,97 +22,147 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& biases_pre = inputs[3];

   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr) {
-    auto stx = arr.strides()[arr.ndim() - 2];
-    auto sty = arr.strides()[arr.ndim() - 1];
-    if (stx == arr.shape(-1) && sty == 1) {
-      return std::make_tuple(false, stx, arr);
-    } else if (stx == 1 && sty == arr.shape(-2)) {
-      return std::make_tuple(true, sty, arr);
+  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
       copy_gpu(arr, arr_copy, CopyType::General, s);
       copies.push_back(arr_copy);
-      size_t stx = arr.shape(-1);
-      return std::make_tuple(false, stx, arr_copy);
+      return arr_copy;
     }
   };

-  auto [x_transposed, x_cols, x] = check_transpose(x_pre);
-  auto [w_transposed, w_cols, w] = check_transpose(w_pre);
-  auto [scales_transposed, scales_cols, scales] = check_transpose(scales_pre);
-  auto [biases_transposed, biases_cols, biases] = check_transpose(biases_pre);
-
-  if (!w_transposed) {
-    throw std::runtime_error("The quantized weight should be transposed.");
-  }
-
-  if (x_transposed || scales_transposed || biases_transposed) {
-    throw std::runtime_error("x, scales and biases should be row contiguous.");
-  }
+  auto x = ensure_row_contiguous(x_pre);
+  auto w = ensure_row_contiguous(w_pre);
+  auto scales = ensure_row_contiguous(scales_pre);
+  auto biases = ensure_row_contiguous(biases_pre);

   int D = x.shape(-1);
   int B = x.size() / D;
+  int O = out.shape(-1);
+  if (transpose_) {
+    // Route to the qmv kernel
+    if (B < 6) {
+      std::ostringstream kname;
+      kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;

-  // Route to the qmv kernel
-  if (B == 1) {
-    std::ostringstream kname;
-    kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
-          << "_gs_" << group_size_ << "_b_" << bits_;
-
-    // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
-
-    int O = w.size() / w_cols;
-
-    int bo = 32;
-    int bd = 32;
-    MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size(1, O / bo, B);
-
-    set_array_buffer(compute_encoder, w, 0);
-    set_array_buffer(compute_encoder, scales, 1);
-    set_array_buffer(compute_encoder, biases, 2);
-    set_array_buffer(compute_encoder, x, 3);
-    set_array_buffer(compute_encoder, out, 4);
-    compute_encoder->setBytes(&D, sizeof(int), 5);
-    compute_encoder->setBytes(&O, sizeof(int), 6);
-
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
-  }
-
-  // Route to the qmm kernel
-  else {
-    std::ostringstream kname;
-    kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
-          << "_gs_" << group_size_ << "_b_" << bits_;
-
-    // Encode and dispatch kernel
-    auto compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
-    compute_encoder->setComputePipelineState(kernel);
-
-    int O = w.size() / w_cols;
-
-    int wn = 2;
-    int wm = 2;
-    int bm = 32;
-    int bn = 32;
-    int bk = 64;
-    MTL::Size group_dims = MTL::Size(32, wn, wm);
-    MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
-
-    set_array_buffer(compute_encoder, x, 0);
-    set_array_buffer(compute_encoder, w, 1);
-    set_array_buffer(compute_encoder, scales, 2);
-    set_array_buffer(compute_encoder, biases, 3);
-    set_array_buffer(compute_encoder, out, 4);
-    compute_encoder->setBytes(&B, sizeof(int), 5);
-    compute_encoder->setBytes(&O, sizeof(int), 6);
-    compute_encoder->setBytes(&D, sizeof(int), 7);
-
-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+      // Encode and dispatch kernel
+      auto compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 32;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, bo, 1);
+      MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+
+      set_array_buffer(compute_encoder, w, 0);
+      set_array_buffer(compute_encoder, scales, 1);
+      set_array_buffer(compute_encoder, biases, 2);
+      set_array_buffer(compute_encoder, x, 3);
+      set_array_buffer(compute_encoder, out, 4);
+      compute_encoder->setBytes(&D, sizeof(int), 5);
+      compute_encoder->setBytes(&O, sizeof(int), 6);
+
+      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    }
+
+    // Route to the qmm_t kernel
+    else {
+      std::ostringstream kname;
+      kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;
+
+      // Encode and dispatch kernel
+      auto compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int wn = 2;
+      int wm = 2;
+      int bm = 32;
+      int bn = 32;
+      int bk = 64;
+      MTL::Size group_dims = MTL::Size(32, wn, wm);
+      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
+
+      set_array_buffer(compute_encoder, x, 0);
+      set_array_buffer(compute_encoder, w, 1);
+      set_array_buffer(compute_encoder, scales, 2);
+      set_array_buffer(compute_encoder, biases, 3);
+      set_array_buffer(compute_encoder, out, 4);
+      compute_encoder->setBytes(&B, sizeof(int), 5);
+      compute_encoder->setBytes(&O, sizeof(int), 6);
+      compute_encoder->setBytes(&D, sizeof(int), 7);
+
+      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    }
+  } else {
+    // Route to the qvm kernel
+    if (B < 4) {
+      std::ostringstream kname;
+      kname << "qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;
+
+      // Encode and dispatch kernel
+      auto compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int bo = 32;
+      int bd = 32;
+      MTL::Size group_dims = MTL::Size(bd, bo, 1);
+      MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
+
+      set_array_buffer(compute_encoder, x, 0);
+      set_array_buffer(compute_encoder, w, 1);
+      set_array_buffer(compute_encoder, scales, 2);
+      set_array_buffer(compute_encoder, biases, 3);
+      set_array_buffer(compute_encoder, out, 4);
+      compute_encoder->setBytes(&D, sizeof(int), 5);
+      compute_encoder->setBytes(&O, sizeof(int), 6);
+
+      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    }
+
+    // Route to the qmm_n kernel
+    else {
+      std::ostringstream kname;
+      kname << "qmm_n_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
+            << bits_;
+
+      // Encode and dispatch kernel
+      auto compute_encoder = d.get_command_encoder(s.index);
+      auto kernel = d.get_kernel(kname.str());
+      compute_encoder->setComputePipelineState(kernel);
+
+      int wn = 2;
+      int wm = 2;
+      int bm = 32;
+      int bn = 64;
+      int bk = 32;
+      MTL::Size group_dims = MTL::Size(32, wn, wm);
+      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
+
+      if ((O % bn) != 0) {
+        std::ostringstream msg;
+        msg << "[quantized_matmul] The output size should be divisible by "
+            << bn << " but received " << O << ".";
+        throw std::runtime_error(msg.str());
+      }
+
+      set_array_buffer(compute_encoder, x, 0);
+      set_array_buffer(compute_encoder, w, 1);
+      set_array_buffer(compute_encoder, scales, 2);
+      set_array_buffer(compute_encoder, biases, 3);
+      set_array_buffer(compute_encoder, out, 4);
+      compute_encoder->setBytes(&B, sizeof(int), 5);
+      compute_encoder->setBytes(&O, sizeof(int), 6);
+      compute_encoder->setBytes(&D, sizeof(int), 7);
+
+      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    }
   }

   d.get_command_buffer(s.index)->addCompletedHandler(

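In short, the GPU routing above picks one of four kernels; a schematic of the decision, with thresholds copied from the code and a function name that is mine:

def pick_kernel(transpose, B):
    # B is the number of rows in x after flattening the batch dimensions
    if transpose:            # computing x @ w.T
        return "qmv" if B < 6 else "qmm_t"
    else:                    # computing x @ w
        return "qvm" if B < 4 else "qmm_n"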
View File

@@ -2618,10 +2618,11 @@ array quantized_matmul(
     const array& w,
     const array& scales,
     const array& biases,
+    bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  auto x = in_x;
+  array x = in_x;

   if (w.dtype() != uint32) {
     std::ostringstream msg;
@@ -2646,39 +2647,52 @@ array quantized_matmul(
     x = reshape(x, {-1, x_inner_dims}, s);
   }

-  int w_inner_dims = w.shape(0) * (32 / bits);
-  if (w_inner_dims != x_inner_dims) {
+  if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
     std::ostringstream msg;
-    msg << "[quantized_matmul] Last dimension of first input with "
-        << "shape (..., " << x_inner_dims
-        << ") does not match the expanded first "
-        << "dimension of the quantized matrix " << w_inner_dims
-        << ", computed from shape " << w.shape()
+    msg << "[quantized_matmul] Scales and biases should have the same 2D shape. "
+        << "Received scales with shape " << scales.shape()
+        << " and biases with " << biases.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) {
+    std::ostringstream msg;
+    msg << "[quantized_matmul] The shapes of the weight and scales are "
+        << "incompatible based on bits and group_size. w.shape() == "
+        << w.shape() << " and scales.shape() == " << scales.shape()
         << " with group_size=" << group_size << " and bits=" << bits;
     throw std::invalid_argument(msg.str());
   }

-  int n_groups = x_inner_dims / group_size;
-  if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
+  // Calculate the expanded w's dims
+  int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0);
+  int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits;
+
+  if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
-    msg << "[quantized_matmul] Scales and biases provided do not match the "
-        << "quantization arguments (group_size=" << group_size
-        << ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", "
-        << x_inner_dims / group_size
-        << "), but got scales.shape=" << scales.shape()
-        << " and biases.shape=" << biases.shape();
+    msg << "[quantized_matmul] Last dimension of first input with "
+        << "shape (..., " << x_inner_dims << ") does not match "
+        << "the expanded quantized matrix (" << w_inner_dims << ", "
+        << w_outer_dims << ") computed from shape " << w.shape()
+        << " with group_size=" << group_size << ", bits=" << bits
+        << " and transpose=" << std::boolalpha << transpose;
     throw std::invalid_argument(msg.str());
   }

+  auto dtype = result_type({x, scales, biases});
   auto out = array(
-      {x.shape(0), w.shape(1)},
-      x.dtype(),
-      std::make_unique<QuantizedMatmul>(to_stream(s), group_size, bits),
-      {x, w, scales, biases});
+      {x.shape(0), w_outer_dims},
+      dtype,
+      std::make_unique<QuantizedMatmul>(
+          to_stream(s), group_size, bits, transpose),
+      {astype(x, dtype, s),
+       w,
+       astype(scales, dtype, s),
+       astype(biases, dtype, s)});

   // If needed reshape x to the original batch shape
   if (original_shape.size() != 1) {
-    original_shape.push_back(w.shape(1));
+    original_shape.push_back(w_outer_dims);
     out = reshape(out, original_shape, s);
   }

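To make the shape checks above concrete, a hedged example of the packed shapes ``quantized_matmul`` expects for each value of ``transpose`` (4 bits and group size 64, i.e. 8 values per uint32):

import mlx.core as mx

K, N = 512, 256

# transpose=True: logical weight is (N, K); packed w is (N, K // 8)
# and scales / biases are (N, K // 64)
wq, sc, bs = mx.quantize(mx.random.normal(shape=(N, K)), group_size=64, bits=4)
# wq: (256, 64), sc and bs: (256, 8)

# transpose=False: logical weight is (K, N); packed w is (K, N // 8)
# and scales / biases are (K, N // 64)
wq, sc, bs = mx.quantize(mx.random.normal(shape=(K, N)), group_size=64, bits=4)
# wq: (512, 32), sc and bs: (512, 4)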
View File

@@ -1041,6 +1041,7 @@ array quantized_matmul(
     const array& w,
     const array& scales,
     const array& biases,
+    bool transpose = true,
     int group_size = 64,
     int bits = 4,
     StreamOrDevice s = {});

View File

@@ -1706,14 +1706,37 @@ std::vector<array> QuantizedMatmul::vjp(
     const std::vector<array>& primals,
     const array& cotan,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("QuantizedMatmul::vjp NYI");
+  std::vector<array> vjps;
+
+  // We rely on the fact that w is always 2D so transpose is simple
+  for (auto arg : argnums) {
+    // gradient wrt to x
+    if (arg == 0) {
+      vjps.push_back(quantized_matmul(
+          cotan,
+          primals[1],
+          primals[2],
+          primals[3],
+          !transpose_,
+          group_size_,
+          bits_,
+          stream()));
+    }
+
+    // gradient wrt to w_q, scales or biases
+    else {
+      throw std::runtime_error(
+          "QuantizedMatmul::vjp no gradient wrt the quantized matrix yet.");
+    }
+  }
+  return vjps;
 }

 array QuantizedMatmul::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  throw std::runtime_error("QuantizedMatmul::vjp NYI");
+  throw std::runtime_error("QuantizedMatmul::jvp NYI");
 }

 bool QuantizedMatmul::is_equivalent(const Primitive& other) const {

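A hedged sketch of exercising the new gradient from Python; only the activation input gets a gradient, while the quantized weight, scales, and biases still raise:

import mlx.core as mx

x = mx.random.normal(shape=(4, 512))
w_q, scales, biases = mx.quantize(mx.random.normal(shape=(256, 512)))

def loss(x):
    return mx.quantized_matmul(x, w_q, scales, biases).sum()

dx = mx.grad(loss)(x)   # backward pass reuses quantized_matmul with the transpose flag flipped
# dx has the same shape as x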
View File

@@ -1112,8 +1112,15 @@ class Power : public Primitive {
 class QuantizedMatmul : public Primitive {
  public:
-  explicit QuantizedMatmul(Stream stream, int group_size, int bits)
-      : Primitive(stream), group_size_(group_size), bits_(bits){};
+  explicit QuantizedMatmul(
+      Stream stream,
+      int group_size,
+      int bits,
+      bool transpose)
+      : Primitive(stream),
+        group_size_(group_size),
+        bits_(bits),
+        transpose_(transpose){};

   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1129,6 +1136,7 @@ class QuantizedMatmul : public Primitive {
  private:
   int group_size_;
   int bits_;
+  bool transpose_;

   void eval(const std::vector<array>& inputs, array& out);
 };

View File

@@ -81,9 +81,10 @@ class QuantizedLinear(Module):
     def __call__(self, x):
         x = mx.quantized_matmul(
             x,
-            self.weight.T,
+            self.weight,
             scales=self.scales,
             biases=self.biases,
+            transpose=True,
             group_size=self.group_size,
             bits=self.bits,
         )

View File

@@ -3072,12 +3072,13 @@ void init_ops(py::module_& m) {
       py::pos_only(),
       "scales"_a,
       "biases"_a,
+      "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
       py::kw_only(),
       "stream"_a = none,
       R"pbdoc(
-        quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
+        quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array

         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -3089,10 +3090,13 @@ void init_ops(py::module_& m) {
            w (array): Quantized matrix packed in unsigned integers
            scales (array): The scales to use per ``group_size`` elements of ``w``
            biases (array): The biases to use per ``group_size`` elements of ``w``
+           transpose (bool, optional): Defines whether to multiply with the
+             transposed ``w`` or not, namely whether we are performing
+             ``x @ w.T`` or ``x @ w``. (default: ``True``)
            group_size (int, optional): The size of the group in ``w`` that
-             shares a scale and bias. (default: 64)
+             shares a scale and bias. (default: ``64``)
            bits (int, optional): The number of bits occupied by each element in
-             ``w``. (default: 4)
+             ``w``. (default: ``4``)

         Returns:
            result (array): The result of the multiplication of ``x`` with ``w``.
@@ -3146,9 +3150,9 @@ void init_ops(py::module_& m) {
         Args:
            w (array): Matrix to be quantized
            group_size (int, optional): The size of the group in ``w`` that shares a
-             scale and bias. (default: 64)
+             scale and bias. (default: ``64``)
            bits (int, optional): The number of bits occupied by each element of
-             ``w`` in the returned quantized matrix. (default: 4)
+             ``w`` in the returned quantized matrix. (default: ``4``)

         Returns:
            (tuple): A tuple containing
@@ -3187,9 +3191,9 @@ void init_ops(py::module_& m) {
            scales (array): The scales to use per ``group_size`` elements of ``w``
            biases (array): The biases to use per ``group_size`` elements of ``w``
            group_size (int, optional): The size of the group in ``w`` that shares a
-             scale and bias. (default: 64)
+             scale and bias. (default: ``64``)
            bits (int, optional): The number of bits occupied by each element in
-             ``w``. (default: 4)
+             ``w``. (default: ``4``)

         Returns:
            result (array): The dequantized version of ``w``

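As the docstring above suggests, the quantized product should track the product with the dequantized matrix up to quantization error; a small check under the listed defaults:

import mlx.core as mx

x = mx.random.normal(shape=(2, 512))
w_q, scales, biases = mx.quantize(mx.random.normal(shape=(256, 512)), group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
err = mx.abs(y - x @ w_hat.T).max()   # typically on the order of 1e-3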
View File

@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import unittest
+from itertools import product

 import mlx.core as mx
 import mlx_tests
@@ -19,62 +20,116 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for group_size in [128, 64]:
-            for bits in [2, 4, 8]:
-                for M in [8, 32, 33, 64]:
-                    for N in [512, 1024]:
-                        for K in [512, 1024]:
-                            with self.subTest(
-                                shape=(M, N, K), group_size=group_size, bits=bits
-                            ):
-                                x = mx.random.normal(shape=(M, K), key=k1)
-                                w = mx.random.normal(shape=(N, K), key=k2)
-                                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                                w_hat = mx.dequantize(
-                                    w_q, scales, biases, group_size, bits
-                                )
-                                y_q = mx.quantized_matmul(
-                                    x, w_q.T, scales, biases, group_size, bits
-                                )
-                                y_hat = x @ w_hat.T
-                                self.assertEqual(y_q.shape, y_hat.shape)
-                                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        tests = product(
+            [128, 64],  # group_size
+            [2, 4, 8],  # bits
+            [8, 32, 33, 64],  # M
+            [512, 1024],  # N
+            [512, 1024],  # K
+            [True, False],  # transposed
+        )
+        for group_size, bits, M, N, K, transposed in tests:
+            with self.subTest(
+                shape=(M, N, K),
+                group_size=group_size,
+                bits=bits,
+                transposed=transposed,
+            ):
+                x = mx.random.normal(shape=(M, K), key=k1)
+                w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+                y_hat = (x @ w_hat.T) if transposed else (x @ w_hat)
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
         group_size = 64
         bits = 4
-        w = mx.random.normal(shape=(32, 128), key=k2)
+        w = mx.random.normal(shape=(32, 256), key=k2)
         w_q, scales, biases = mx.quantize(w, group_size, bits)
         w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
-        for s in [(3, 128), (2, 1, 7, 128)]:
-            x = mx.random.normal(shape=(3, 128), key=k1)
-            y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
+        for s in [(3, 256), (2, 1, 7, 256)]:
+            x = mx.random.normal(shape=s, key=k1)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits)
             y_hat = x @ w_hat.T
             self.assertEqual(y_q.shape, y_hat.shape)
             self.assertLess((y_q - y_hat).abs().max(), 1e-3)

+        w = mx.random.normal(shape=(256, 256), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        for s in [(3, 256), (2, 1, 7, 256)]:
+            x = mx.random.normal(shape=s, key=k1)
+            y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+            y_hat = x @ w_hat
+            self.assertEqual(y_q.shape, y_hat.shape)
+            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
     def test_qmv(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for group_size in [128, 64]:
-            for bits in [2, 4, 8]:
-                for M in [512, 1024]:
-                    for N in [512, 1024]:
-                        with self.subTest(
-                            shape=(M, N), group_size=group_size, bits=bits
-                        ):
-                            x = mx.random.normal(shape=(1, N), key=k1)
-                            w = mx.random.normal(shape=(M, N), key=k2)
-                            w_q, scales, biases = mx.quantize(w, group_size, bits)
-                            w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
-                            y_q = mx.quantized_matmul(
-                                x, w_q.T, scales, biases, group_size, bits
-                            )
-                            y_hat = x @ w_hat.T
-                            self.assertEqual(y_q.shape, y_hat.shape)
-                            self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+        tests = product(
+            [128, 64],  # group_size
+            [2, 4, 8],  # bits
+            [512, 1024],  # M
+            [512, 1024],  # N
+        )
+        for group_size, bits, M, N in tests:
+            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+                x = mx.random.normal(shape=(1, N), key=k1)
+                w = mx.random.normal(shape=(M, N), key=k2)
+                w_q, scales, biases = mx.quantize(w, group_size, bits)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, True, group_size, bits
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[128, 64], # group_size
[2, 4, 8], # bits
[512, 1024], # M
[512, 1024], # N
)
for group_size, bits, M, N in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(N, M), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))
w_q, scales, biases = mx.quantize(w)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q.T, scales, biases)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q.T, scales.T, biases)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q, scales, biases, False)
with self.assertRaises(ValueError):
mx.quantized_matmul(x, w_q, scales.T, biases.T)
y = mx.quantized_matmul(x, w_q, scales, biases, True)
mx.eval(y)
if __name__ == "__main__":