Compare commits

...

19 Commits

Author SHA1 Message Date
Angelos Katharopoulos
c02e14c264 Add the 3bit packed qmm_t 2024-12-17 22:16:30 -08:00
Angelos Katharopoulos
d75a509234 Add 3bit packed quants 2024-12-17 10:49:13 -08:00
Angelos Katharopoulos
14420949d2 Fix the optional in gather_qmm python binding 2024-12-16 22:14:19 -08:00
Angelos Katharopoulos
4847199ec6 Add the quantization type option to quantizable layers 2024-12-16 22:11:23 -08:00
Angelos Katharopoulos
fb7be036af Add packed_affine_qmm_t 2024-12-16 21:49:14 -08:00
Angelos Katharopoulos
410ccdbed5 Change the argument name to quantization_type 2024-12-16 13:32:20 -08:00
Angelos Katharopoulos
f5da489a3c Add some error reporting 2024-12-16 13:22:05 -08:00
Angelos Katharopoulos
c2e6d58441 Revert the change in packing order 2024-12-16 13:20:01 -08:00
Angelos Katharopoulos
17a1fa2f0b Improve the benchmark 2024-12-14 23:04:29 -08:00
Angelos Katharopoulos
fd161aa31f Change order in weight packing 2024-12-14 22:51:41 -08:00
Angelos Katharopoulos
bf6dc54110 Add the 2 bit vectorized reads 2024-12-14 21:19:02 -08:00
Angelos Katharopoulos
d7ed624502 Vectorized reads 2024-12-14 15:36:34 -08:00
Angelos Katharopoulos
05cb54ae3f Another packing 2024-12-13 23:48:25 -08:00
Angelos Katharopoulos
cb358dbdda Revert "Attempt different packing"
This reverts commit e4b587819c.
2024-12-13 23:23:41 -08:00
Angelos Katharopoulos
e4b587819c Attempt different packing 2024-12-13 18:36:36 -08:00
Angelos Katharopoulos
a06c968f4d Add a small benchmark 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
651c510940 Working packed qmv 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
11ec07ff9d Initial python binding 2024-12-13 16:29:11 -08:00
Angelos Katharopoulos
bdd68bd893 Add a quantization type in the ops 2024-12-13 16:29:11 -08:00
14 changed files with 1291 additions and 71 deletions

View File

@@ -0,0 +1,74 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 10


def qmm_(x, wq1, wq2, q_type):
    for i in range(loops):
        x = mx.quantized_matmul(
            x,
            *wq1,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
        x = mx.quantized_matmul(
            x,
            *wq2,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
    return x


def affine_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine")


def affine_packed_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine-packed")


def time_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
    mx.eval(x, wq1, wq2)
    time_fn(affine_qmm, x, wq1, wq2)


def time_packed_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(
        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(
        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    mx.eval(x, wq1, wq2)
    time_fn(affine_packed_qmm, x, wq1, wq2)


if __name__ == "__main__":
    for b in [2, 4, 8]:
        bits = b
        print(f"Bits {bits}:")
        time_qmm()
        time_packed_qmm()
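
For context, a minimal sketch (not part of the diff) of the API difference the benchmark leans on, assuming the bindings from this branch are built; the shapes are illustrative only:

import mlx.core as mx

w = mx.random.normal(shape=(1024, 1024)).astype(mx.float16)

# "affine" keeps the familiar (w_q, scales, biases) triple.
wq, scales, biases = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")

# "affine-packed" folds the biases into the scales array and returns None in
# their place, so the same *wq-style unpacking feeds quantized_matmul above.
wq_p, scales_p, biases_p = mx.quantize(
    w, group_size=64, bits=4, quantization_type="affine-packed"
)
assert biases_p is None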

View File

@@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant int64_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant int64_t* w_strides,
const constant int64_t* s_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
} else {
ulong2 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
}
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
@@ -2149,3 +2184,666 @@ template <typename T, const int group_size, const int bits>
}
}
}
template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
vec<U, 4> accum = 0;
if (bits == 2) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) +
x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0));
}
}
}
else if (bits == 4) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
for (int j = 0; j < 2; j++) {
accum[i] +=
(x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
}
}
}
else if (bits == 8) {
for (int i = 0; i < 4; i++) {
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
for (int j = 0; j < 4; j++) {
accum[i] += x[j] * ws[j];
}
}
}
return accum;
}
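
partial_qdot_vec consumes one uint32 per output row, four rows at a time to match the packed layout, and accumulates the masked products. The unshifted masks (0x0c, 0x00f0, 0x0f00, ...) follow the same convention as qdot in this file, relying on load_vector pre-scaling the corresponding activations so no shifts are needed in the inner loop. A plain NumPy reference of the unpacking itself, with explicit shifts rather than the mask trick, purely as an illustration:

import numpy as np

def partial_qdot_ref(x, packed_u32, bits=4):
    # One uint32 holds 32 // bits quantized codes; extract them with explicit
    # shifts and dot them with the matching slice of activations.
    pack_factor = 32 // bits
    codes = (int(packed_u32) >> (bits * np.arange(pack_factor))) & ((1 << bits) - 1)
    return float(np.dot(x[:pack_factor], codes))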
template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
const device vec<uint32_t, 4>* w,
const device vec<T, 4>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
typedef float U;
thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size * 2 / group_size;
const int w_row = tid.x * num_simdgroups + simd_gid;
const int out_row = w_row * results_per_simdgroup;
w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
// Load the scales and biases
vec<T, 4> s = scales[0];
vec<T, 4> b = scales[1];
// Load the weights and perform the partial dot product
vec<U, 4> accum = 0;
for (int pack = 0; pack < packs_per_thread; pack++) {
accum +=
partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
}
// Finalize the dot product and accumulate it
for (int i = 0; i < 4; i++) {
result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
}
w += block_size / pack_factor;
scales += 2 * block_size / group_size;
x += block_size;
}
result = simd_sum(result);
if (simd_lid == 0) {
for (int row = 0; row < results_per_simdgroup; row++) {
y[row] = static_cast<T>(result[row]);
}
}
}
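
The accumulation line above is the affine dequantization pushed through the dot product: with per-group scale s and bias b each weight is s * q + b, so

    sum_k x_k * (s * q_k + b) = s * (sum_k x_k * q_k) + b * (sum_k x_k)

which is why load_vector also returns the running sum of the activations; accum supplies the first term for each of the four output rows and sum supplies the second.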
template <typename T, int group_size, int bits, int results_per_simdgroup>
METAL_FUNC void affine_packed_byte_qmv_fast_impl(
const device uint8_t* w,
const device vec<T, 2 * results_per_simdgroup>* scales,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = (bits == 3) ? 8 : 4;
constexpr int bytes_per_pack = 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
typedef float U;
thread U x_thread[values_per_thread];
vec<U, results_per_simdgroup> result = 0;
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int scales_row = tid.x * num_simdgroups + simd_gid;
const int out_row = scales_row * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.y * in_vec_size + simd_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
// Load the input vector
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
// Load the scales and biases
vec<T, 2 * results_per_simdgroup> sb = scales[0];
// Load the weights and perform the partial dot product
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] += qdot<U, values_per_thread, bits>(
w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
}
w += block_size * bytes_per_pack / pack_factor;
scales += block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
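
This byte variant covers the non-power-of-2 widths: a row is a stream of 3-byte packs (for 3-bit, pack_factor = 8 codes per bytes_per_pack = 3 bytes), and only the scales and biases are row-packed, two rows per group, while the weights stay in their usual per-row layout and go through the regular qdot. A rough Python illustration of that density; the actual bit order is whatever qdot/dequantize define, so the consecutive-field layout below is an assumption:

import numpy as np

def unpack_3bit_pack(three_bytes):
    # View the 3 bytes as one 24-bit little-endian word and slice it into
    # eight consecutive 3-bit codes (pack_factor = 8, bytes_per_pack = 3).
    word = int(three_bytes[0]) | (int(three_bytes[1]) << 8) | (int(three_bytes[2]) << 16)
    return np.array([(word >> (3 * j)) & 0x7 for j in range(8)], dtype=np.int32)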
template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
const device vec<uint32_t, 4>* w [[buffer(0)]],
const device vec<T, 4>* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (bits & (bits - 1)) {
affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
(const device uint8_t*)w,
scales,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
affine_packed_qmv_fast_impl<T, group_size, bits>(
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
}
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
short bits>
struct AffinePackedQuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short pack_factor = 32 / bits;
MLX_MTL_CONST short row_pack_factor = 4;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
MLX_MTL_CONST short n_reads =
(TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
static_assert(
n_reads <= row_pack_factor,
"The loader only supports per thread reads <= row_pack_factor");
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
const short bii;
const short bjj;
const device uint32_t* src;
const device T* scales;
const device T* biases;
threadgroup T* dst;
AffinePackedQuantizedBlockLoader(
const device uint32_t* src_,
const device T* scales_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
group_step_cnt(0),
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
bii(bi * row_pack_factor + bj % row_pack_factor),
bjj(bj / row_pack_factor),
src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
scales(
scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
bj % row_pack_factor),
biases(scales + row_pack_factor),
dst(dst_ + bii * dst_ld + bjj * pack_factor) {}
void load_unsafe() const {
if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void load_safe(short2 src_tile_dim) const {
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
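
To make the tile math concrete, a back-of-envelope sketch of the loader constants for 4-bit weights, assuming BN = BK = 32 and WM = WN = 2 as used by affine_packed_qmm_t_impl below, and a simdgroup width (SIMD_SIZE) of 32:

# Sketch only: reproduces the compile-time arithmetic of the loader above.
bits = 4
BROWS, BCOLS = 32, 32
tgp_size = 2 * 2 * 32                                   # WM * WN * SIMD_SIZE
pack_factor = 32 // bits                                # 8 codes per uint32
row_pack_factor = 4                                     # 4 output rows interleaved per uint32
BCOLS_PACKED = BCOLS * row_pack_factor // pack_factor   # 16 uint32 per packed tile row
BROWS_PACKED = BROWS // row_pack_factor                 # 8 packed rows per tile
TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED                # 128 uint32 per tile
n_reads = 1 if TOTAL_INTS < tgp_size else TOTAL_INTS // tgp_size  # 1 uint32 per thread

so with a 128-thread threadgroup each thread issues a single packed read per tile.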
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short group_size,
short bits>
struct AffineScalesPackedQuantizedBlockLoader {
static_assert(
BCOLS <= group_size,
"The group size should be larger than the columns");
static_assert(
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
"Template undefined for bits not in {2, 3, 4, 6, 8}");
MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
MLX_MTL_CONST short row_pack_factor = 2;
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
MLX_MTL_CONST short n_reads =
(TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
MLX_MTL_CONST short group_steps = group_size / BCOLS;
const int src_ld;
const int tile_stride;
short group_step_cnt;
const int group_stride;
const short thread_idx;
const short bi;
const short bj;
const short bii;
const device uint8_t* src;
const device T* scales;
const device T* biases;
threadgroup T* dst;
AffineScalesPackedQuantizedBlockLoader(
const device uint32_t* src_,
const device T* scales_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(
reduction_dim ? BCOLS_PACKED * bytes_per_pack
: BROWS * src_ld * bytes_per_pack / pack_factor),
group_step_cnt(0),
group_stride(BROWS_PACKED * 2 * src_ld / group_size),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(n_reads * thread_idx / BCOLS_PACKED),
bj((n_reads * thread_idx) % BCOLS_PACKED),
bii(bi / row_pack_factor),
src(((const device uint8_t*)src_) +
bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
scales(
scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
bi % row_pack_factor),
biases(scales + row_pack_factor),
dst(dst_ + bi * dst_ld + bj * pack_factor) {}
void load_unsafe() const {
if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
return;
}
T scale = *scales;
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + bytes_per_pack * i),
scale,
bias,
dst + i * pack_factor);
}
}
void load_safe(short2 src_tile_dim) const {
if (TOTAL_READS < tgp_size && bi >= BROWS) {
return;
}
if (reduction_dim == 1 && bii >= src_tile_dim.y) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
if (reduction_dim == 0 && bii >= src_tile_dim.x) {
for (int i = 0; i < n_reads * pack_factor; i++) {
dst[i] = T(0);
}
return;
}
for (int i = 0; i < n_reads; i++) {
T scale = scales[i];
T bias = biases[i];
dequantize<T, pack_factor, bits>(
(const device uint8_t*)(src + bytes_per_pack * i * src_ld),
scale,
bias,
dst + i * dst_ld);
}
}
void next() {
src += tile_stride;
if (reduction_dim == 1) {
if (group_steps > 1) {
group_step_cnt++;
if (group_step_cnt == group_steps) {
group_step_cnt = 0;
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += (2 * row_pack_factor);
biases += (2 * row_pack_factor);
}
} else {
scales += group_stride;
biases += group_stride;
}
}
};
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void affine_packed_qmm_t_impl(
const device uint32_t* w,
const device T* scales,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
(void)lid;
constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int WM = 2;
constexpr int WN = 2;
constexpr int pack_factor = 32 / bits;
constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
constexpr int BK_padded = (BK + 16 / sizeof(T));
// Instantiate the appropriate BlockMMA and Loader
using mma_t = mlx::steel::
BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
using loader_x_t =
mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
T,
BN,
BK,
BK_padded,
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
using loader_w_t = typename ConditionalType<
power_of_2_bits,
loader_fully_packed_t,
loader_scales_packed_t>::type;
// Set the block
const int K_w =
(power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
const int K_g = K * 2 * row_pack_factor / group_size;
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
const int packed_y_col = tid.x * (BN / row_pack_factor);
x += y_row * K;
w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
scales += packed_y_col * K_g;
y += y_row * N + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
if (num_els < BM) {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_safe(short2(BK, num_els));
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
} else {
if (!aligned_N && num_outs < BN) {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_safe(short2(BK, num_outs));
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
} else {
for (int k = 0; k < K; k += BK) {
threadgroup_barrier(mem_flags::mem_threadgroup);
loader_x.load_unsafe();
loader_w.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(Xs, Ws);
loader_x.next();
loader_w.next();
}
}
}
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
}
template <
typename T,
const int group_size,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void affine_packed_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* x [[buffer(2)]],
device T* y [[buffer(3)]],
const constant int& K [[buffer(4)]],
const constant int& N [[buffer(5)]],
const constant int& M [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant int64_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant int64_t* w_strides [[buffer(12)]],
const constant int64_t* s_strides [[buffer(13)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr int BK_padded = (BK + 16 / sizeof(T));
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
tid);
}
affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

View File

@@ -60,6 +60,14 @@
bits, \
split_k)
#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -96,12 +104,20 @@
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_all_affine_packed(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -377,10 +377,187 @@ void qmm_op(
s);
}
void affine_packed_qmv(
const std::vector<array>& inputs,
array& out,
int B,
int D,
int O,
int group_size,
int bits,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& d = metal::device(s.device);
auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
d.add_temporary(arr_copy, s.index);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(inputs[0]);
auto w = ensure_row_contiguous_last_dims(inputs[1]);
auto scales = ensure_row_contiguous_last_dims(inputs[2]);
const bool pow2_bits = (bits & (bits - 1)) == 0;
const int n_simdgroups = 2;
const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
MTL::Size group_dims(32, n_simdgroups, 1);
MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);
std::string name;
name.reserve(64);
concatenate(
name,
(D % 512 == 0) ? "affine_packed_qmv_fast_" : "affine_packed_qmv_",
get_type_string(out.dtype()),
"_gs_",
std::to_string(group_size),
"_b_",
std::to_string(bits));
auto kernel = get_quantized_kernel(d, name, "");
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(x, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(D, 4);
compute_encoder.set_bytes(O, 5);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void affine_packed_qmm_t(
const std::vector<array>& inputs,
array& out,
bool batched,
int B,
int D,
int O,
int group_size,
int bits,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& d = metal::device(s.device);
auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
d.add_temporary(arr_copy, s.index);
return arr_copy;
}
};
// TODO: Deal with this in routing towards qmm_n instead of qmm_t
auto x = ensure_row_contiguous_last_dims(inputs[0]);
auto w = ensure_row_contiguous_last_dims(inputs[1]);
auto scales = ensure_row_contiguous_last_dims(inputs[2]);
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
auto& x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto& w_shape = w.shape();
auto& w_strides = w.strides();
auto& s_strides = scales.strides();
const int wn = 2;
const int wm = 2;
const int bm = 32;
const int bn = 32;
const int N = (batched) ? out.size() / B / O : 1;
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N);
std::string name;
name.reserve(64);
concatenate(
name,
"affine_packed_qmm_t_",
get_type_string(out.dtype()),
"_gs_",
std::to_string(group_size),
"_b_",
std::to_string(bits),
"_alN_",
((O % 32) == 0) ? "true" : "false",
"_batch_",
(batched) ? "true" : "false");
auto kernel = get_quantized_kernel(d, name, "");
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(x, 2);
compute_encoder.set_output_array(out, 3);
compute_encoder.set_bytes(D, 4);
compute_encoder.set_bytes(O, 5);
compute_encoder.set_bytes(B, 6);
if (batched) {
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
}
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void affine_packed_qmm_op(
const std::vector<array>& inputs,
array& out,
bool transpose,
int group_size,
int bits,
const Stream& s) {
auto& x = inputs[0];
auto& w = inputs[1];
bool batched = w.ndim() > 2;
int D = x.shape(-1);
int O = out.shape(-1);
int B = (batched) ? x.shape(-2) : x.size() / D;
if (transpose) {
if (B < 6) {
affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
} else {
affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s);
}
} else {
}
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
if (type_ == QuantizationType::Affine) {
assert(inputs.size() == 4);
qmm_op(
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/false,
stream());
}
if (type_ == QuantizationType::AffinePacked) {
assert(inputs.size() == 3);
affine_packed_qmm_op(inputs, out, transpose_, group_size_, bits_, stream());
}
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -75,10 +75,35 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
bool transpose,
int group_size,
int bits) {
int bits,
QuantizationType quantization_type) {
// Check if we have biases as expected
switch (quantization_type) {
case QuantizationType::Affine:
if (!biases.has_value()) {
std::ostringstream msg;
msg << "[" << tag
<< "] The biases argument is required for quantization "
<< "type '" << quantization_type << "'";
throw std::invalid_argument(msg.str());
}
break;
case QuantizationType::AffinePacked:
if (biases.has_value()) {
std::ostringstream msg;
msg << "[" << tag << "] Quantization type '" << quantization_type
<< "' does not use "
<< "biases but biases were provided";
throw std::invalid_argument(msg.str());
}
break;
}
bool pow2_bits = (bits & (bits - 1)) == 0;
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -86,11 +111,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}
if (scales.shape() != biases.shape()) {
if (biases.has_value() && scales.shape() != biases.value().shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
<< " and biases with " << biases.value().shape();
throw std::invalid_argument(msg.str());
}
@@ -99,25 +124,42 @@ std::pair<int, int> extract_quantized_matmul_dims(
std::ostringstream msg;
msg << "[" << tag
<< "] Weight, scales and biases should have the same batch shape. "
<< "Received weight with shape " << w.shape() << ", scales with "
<< scales.shape() << " and biases with " << biases.shape();
<< "Received weight with shape " << w.shape()
<< " and scales/biases with " << scales.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
int weight_dims = w.shape(-1) * 32 / bits;
int scales_dims = scales.shape(-1) * group_size;
if (quantization_type == QuantizationType::AffinePacked) {
if (pow2_bits) {
scales_dims /= 8;
weight_dims /= 4;
} else {
scales_dims /= 4;
}
}
if (weight_dims != scales_dims) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
<< "incompatible based on bits, group_size and quantization type. "
<< "w.shape() == " << w.shape()
<< " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << ", bits=" << bits
<< " and type='" << quantization_type << "'";
throw std::invalid_argument(msg.str());
}
int x_inner_dims = x.shape(-1);
// Calculate the expanded w's dims
int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
int weight_dims_other = w.shape(-2);
if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
weight_dims_other *= 4;
}
int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
int w_outer_dims = (transpose) ? weight_dims_other : weight_dims;
if (w_inner_dims != x_inner_dims) {
std::ostringstream msg;
@@ -3662,14 +3704,23 @@ array quantized_matmul(
array x,
array w,
array scales,
array biases,
std::optional<array> biases,
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationType quantization_type /* = QuantizationType::Affine */,
StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
"quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
quantization_type);
// QuantizedMatmul handles w.ndim == 2 case.
if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3690,69 +3741,134 @@ array quantized_matmul(
*(inner_shape.end() - 1) = scales.shape(-1);
scales = broadcast_to(scales, inner_shape, s);
*(inner_shape.end() - 1) = biases.shape(-1);
biases = broadcast_to(biases, inner_shape, s);
if (biases.has_value()) {
*(inner_shape.end() - 1) = biases.value().shape(-1);
biases = broadcast_to(biases.value(), inner_shape, s);
}
}
auto dtype = result_type(x, scales, biases);
auto dtype = result_type(x, scales);
if (biases.has_value()) {
dtype = promote_types(dtype, biases.value().dtype());
}
if (!issubdtype(dtype, floating)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype()
<< ", scales.dtype() == " << scales.dtype()
<< " and biases.dtype() == " << biases.dtype();
<< ", scales.dtype() == " << scales.dtype();
if (biases.has_value()) {
msg << " and biases.dtype() == " << biases.value().dtype();
}
throw std::invalid_argument(msg.str());
}
// Prepare the inputs vector
std::vector<array> inputs;
inputs.reserve(4);
inputs.push_back(astype(x, dtype, s));
inputs.push_back(w);
inputs.push_back(astype(scales, dtype, s));
if (biases.has_value()) {
inputs.push_back(astype(biases.value(), dtype, s));
}
auto out_shape = x.shape();
out_shape.back() = w_outer_dims;
return array(
std::move(out_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
{astype(x, dtype, s),
w,
astype(scales, dtype, s),
astype(biases, dtype, s)});
to_stream(s), quantization_type, group_size, bits, transpose),
std::move(inputs));
}
std::tuple<array, array, array> quantize(
std::tuple<array, array, std::optional<array>> quantize(
const array& w,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationType quantization_type /* = QuantizationType::Affine */,
StreamOrDevice s /* = {} */) {
return fast::affine_quantize(w, group_size, bits, s);
switch (quantization_type) {
case QuantizationType::Affine:
return fast::affine_quantize(w, group_size, bits, s);
case QuantizationType::AffinePacked: {
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
int pow2_bits = (bits & (bits - 1)) == 0;
int row_packing = (pow2_bits) ? 4 : 2;
scales = unflatten(scales, -2, {-1, row_packing}, s);
biases = unflatten(biases, -2, {-1, row_packing}, s);
scales = concatenate({scales, biases}, -2, s);
scales = moveaxis(scales, -2, -1, s);
scales = flatten(scales, -2, -1, s);
if (pow2_bits) {
wq = unflatten(wq, -2, {-1, row_packing}, s);
wq = moveaxis(wq, -2, -1, s);
wq = flatten(wq, -2, -1, s);
}
return std::make_tuple(wq, scales, std::nullopt);
}
}
}
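
For reference, a Python sketch of the packing performed above for a power-of-2 bit width; it mirrors the unflatten/concatenate/moveaxis/flatten calls using the corresponding Python ops, with illustrative shapes, rather than adding anything new:

import mlx.core as mx

w = mx.random.normal(shape=(1024, 1024)).astype(mx.float16)
wq, scales, biases = mx.quantize(w, group_size=64, bits=4)  # plain affine
# wq: (1024, 128) uint32, scales and biases: (1024, 16)

row_packing = 4  # 4 for power-of-2 bits, 2 otherwise

s = mx.unflatten(scales, -2, (-1, row_packing))   # (256, 4, 16)
b = mx.unflatten(biases, -2, (-1, row_packing))   # (256, 4, 16)
sb = mx.concatenate([s, b], axis=-2)              # (256, 8, 16)
sb = mx.moveaxis(sb, -2, -1)                      # (256, 16, 8)
sb = mx.flatten(sb, -2, -1)                       # (256, 128): [4 scales, 4 biases] per group

wq_p = mx.unflatten(wq, -2, (-1, row_packing))    # (256, 4, 128)
wq_p = mx.moveaxis(wq_p, -2, -1)                  # (256, 128, 4)
wq_p = mx.flatten(wq_p, -2, -1)                   # (256, 512): 4 rows interleaved per column

This is also what the relaxed shape check in extract_quantized_matmul_dims undoes: 512 uint32 * 8 codes / 4 interleaved rows = 1024 = K, and 128 * 64 / 8 = 1024 as well.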
array dequantize(
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationType quantization_type /* = QuantizationType::Affine */,
StreamOrDevice s /* = {} */) {
return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
return fast::affine_dequantize(
w, scales, biases.value(), group_size, bits, s);
}
array gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
std::optional<array> lhs_indices_ /* = std::nullopt */,
std::optional<array> rhs_indices_ /* = std::nullopt */,
const std::optional<array>& biases,
const std::optional<array>& lhs_indices_ /* = std::nullopt */,
const std::optional<array>& rhs_indices_ /* = std::nullopt */,
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationType quantization_type /* = QuantizationType::Affine */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s);
x,
w,
scales,
biases,
transpose,
group_size,
bits,
quantization_type,
s);
}
if (quantization_type != QuantizationType::Affine) {
std::ostringstream msg;
msg << "[gather_qmm] Only quantization type '" << quantization_type
<< "' supported";
throw std::invalid_argument(msg.str());
}
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
"gather_qmm",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
quantization_type);
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3768,16 +3884,17 @@ array gather_qmm(
out_shape.push_back(w_outer_dims);
// and output type
auto out_type = result_type(x, scales, biases);
auto out_type = result_type(x, scales, biases.value());
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s), quantization_type, group_size, bits, transpose),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),
astype(biases, out_type, s),
astype(biases.value(), out_type, s),
lhs_indices,
rhs_indices});
}

View File

@@ -1277,31 +1277,34 @@ array conv_transpose3d(
int groups = 1,
StreamOrDevice s = {});
/** Quantized matmul multiplies x with a quantized matrix w*/
/** Quantized matmul multiplies x with a quantized matrix w */
array quantized_matmul(
array x,
array w,
array scales,
array biases,
std::optional<array> biases,
bool transpose = true,
int group_size = 64,
int bits = 4,
QuantizationType quantization_type = QuantizationType::Affine,
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
std::tuple<array, array, std::optional<array>> quantize(
const array& w,
int group_size = 64,
int bits = 4,
QuantizationType quantization_type = QuantizationType::Affine,
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
array dequantize(
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
int group_size = 64,
int bits = 4,
QuantizationType quantization_type = QuantizationType::Affine,
StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */
@@ -1309,12 +1312,13 @@ array gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
std::optional<array> lhs_indices = std::nullopt,
std::optional<array> rhs_indices = std::nullopt,
const std::optional<array>& biases,
const std::optional<array>& lhs_indices = std::nullopt,
const std::optional<array>& rhs_indices = std::nullopt,
bool transpose = true,
int group_size = 64,
int bits = 4,
QuantizationType quantization_type = QuantizationType::Affine,
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */

View File

@@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
cotangents[0],
primals[1],
primals[2],
primals[3],
(primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
!transpose_,
group_size_,
bits_,
type_,
stream()));
}
@@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
type_,
stream()),
-3,
stream()),

View File

@@ -8,6 +8,7 @@
#include "mlx/device.h"
#include "mlx/io/load.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
#define DEFINE_VMAP() \
virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
@@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
public:
explicit QuantizedMatmul(
Stream stream,
QuantizationType type,
int group_size,
int bits,
bool transpose)
: UnaryPrimitive(stream),
type_(type),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
@@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
QuantizationType type_;
int group_size_;
int bits_;
bool transpose_;
@@ -1595,8 +1599,14 @@ class QuantizedMatmul : public UnaryPrimitive {
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
QuantizationType type,
int group_size,
int bits,
bool transpose)
: UnaryPrimitive(stream),
type_(type),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
@@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
QuantizationType type_;
int group_size_;
int bits_;
bool transpose_;

View File

@@ -145,6 +145,30 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
return os;
}
std::ostream& operator<<(std::ostream& os, QuantizationType type) {
std::string_view quantization_type;
switch (type) {
case QuantizationType::Affine:
quantization_type = "affine";
break;
case QuantizationType::AffinePacked:
quantization_type = "affine-packed";
break;
}
return os << quantization_type;
}
QuantizationType from_string(const std::string& s) {
if (s == "affine") {
return QuantizationType::Affine;
}
if (s == "affine-packed") {
return QuantizationType::AffinePacked;
}
throw std::invalid_argument("Cannot map '" + s + "' to a quantization type");
}
namespace {
inline size_t

View File

@@ -100,6 +100,19 @@ inline int next_power_of_2(int n) {
return pow(2, std::ceil(std::log2(n)));
}
/** Enumerate the different quantization types */
enum class QuantizationType {
// Traditional affine quantization with separate scales and biases
Affine,
// The same quantization as affine but with the scales and biases packed in a
// single array and contiguously every 4 rows
AffinePacked,
};
std::ostream& operator<<(std::ostream& os, QuantizationType type);
QuantizationType from_string(const std::string& s);
namespace env {
int get_var(const char* name, int default_value);

View File

@@ -39,6 +39,10 @@ class Embedding(Module):
"""
return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)
return QuantizedEmbedding.from_embedding(
self, group_size, bits, quantization_type
)

View File

@@ -70,9 +70,11 @@ class Linear(Module):
x = x @ self["weight"].T
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)
class Bilinear(Module):

View File

@@ -12,6 +12,7 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
"""Quantize the sub-modules of a module according to a predicate.
@@ -39,7 +40,11 @@ def quantize(
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
return m.to_quantized(
group_size=group_size,
bits=bits,
quantization_type=quantization_type,
)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
else:
@@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
if quantization_type != "affine":
raise ValueError(f"Quantization type {quantization_type} not supported")
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -170,12 +181,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.quantization_type = quantization_type
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@@ -184,7 +197,9 @@ class QuantizedLinear(Module):
high=scale,
shape=(output_dims, input_dims),
)
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
self.weight, self.scales, self.biases = mx.quantize(
weight, group_size, bits, quantization_type=quantization_type
)
# And bias if needed
if bias:
@@ -212,22 +227,29 @@ class QuantizedLinear(Module):
x,
self["weight"],
scales=self["scales"],
biases=self["biases"],
biases=self.get("biases", None),
transpose=True,
group_size=self.group_size,
bits=self.bits,
quantization_type=self.quantization_type,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
quantization_type: str = "affine",
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
linear_layer.weight, group_size, bits, quantization_type=quantization_type
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
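
A short usage sketch of the new layer-level option (the model here is illustrative); note that QuantizedEmbedding still rejects anything but "affine" above:

import mlx.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Quantize every quantizable sub-module with the packed affine layout.
nn.quantize(model, group_size=64, bits=4, quantization_type="affine-packed")

# Or convert a single layer explicitly:
layer = nn.QuantizedLinear.from_linear(
    nn.Linear(1024, 1024), group_size=64, bits=4, quantization_type="affine-packed"
)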

View File

@@ -4018,18 +4018,38 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"quantized_matmul",
&mx::quantized_matmul,
[](mx::array x,
mx::array w,
mx::array scales,
std::optional<mx::array> biases,
bool transpose,
int group_size,
int bits,
const std::string& quantization_type,
mx::StreamOrDevice s) {
return mx::quantized_matmul(
std::move(x),
std::move(w),
std::move(scales),
std::move(biases),
transpose,
group_size,
bits,
mx::from_string(quantization_type),
s);
},
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"quantization_type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) {
x (array): Input array
w (array): Quantized matrix packed in unsigned integers
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
biases (array, optional): The biases to use per ``group_size``
elements of ``w`` depending on the quantization type
transpose (bool, optional): Defines whether to multiply with the
transposed ``w`` or not, namely whether we are performing
``x @ w.T`` or ``x @ w``. Default: ``True``.
@@ -4048,20 +4069,30 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
quantization_type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
)pbdoc");
m.def(
"quantize",
&mx::quantize,
[](const mx::array& w,
int group_size,
int bits,
const std::string& quantization_type,
mx::StreamOrDevice s) {
return mx::quantize(
w, group_size, bits, mx::from_string(quantization_type), s);
},
nb::arg(),
"group_size"_a = 64,
"bits"_a = 4,
"quantization_type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4103,26 +4134,46 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``.
quantization_type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
tuple: A tuple containing
* w_q (array): The quantized version of ``w``
* scales (array): The scale to multiply each element with, namely :math:`s`
* biases (array): The biases to add to each element, namely :math:`\beta`
* biases (array, optional): The biases to add to each element, namely
  :math:`\beta`. Depending on the quantization type this return value
  may be None.
)pbdoc");
m.def(
"dequantize",
&mx::dequantize,
[](const mx::array& wq,
const mx::array& scales,
const std::optional<mx::array>& biases,
int group_size,
int bits,
const std::string& quantization_type,
mx::StreamOrDevice s) {
return mx::dequantize(
wq,
scales,
biases,
group_size,
bits,
mx::from_string(quantization_type),
s);
},
nb::arg(),
"scales"_a,
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
"quantization_type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dequantize the matrix ``w`` using the provided ``scales`` and
``biases`` and the ``group_size`` and ``bits`` configuration.
@@ -4143,6 +4194,8 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
quantization_type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The dequantized version of ``w``
@@ -4153,16 +4206,17 @@ void init_ops(nb::module_& m) {
nb::arg(),
nb::arg(),
"scales"_a,
"biases"_a,
"biases"_a = nb::none(),
"lhs_indices"_a = nb::none(),
"rhs_indices"_a = nb::none(),
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"quantization_type"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@@ -4188,6 +4242,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
quantization_type (str, optional): The type of quantization used for the matrix.
It can be 'affine' or 'affine-packed'.
Returns:
array: The result of the multiplication of ``x`` with ``w``