Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-12 15:24:57 +08:00)

Compare commits: sdpa-test ... packed-qua (19 commits)

c02e14c264
d75a509234
14420949d2
4847199ec6
fb7be036af
410ccdbed5
f5da489a3c
c2e6d58441
17a1fa2f0b
fd161aa31f
bf6dc54110
d7ed624502
05cb54ae3f
cb358dbdda
e4b587819c
a06c968f4d
651c510940
11ec07ff9d
bdd68bd893
benchmarks/python/packed_qmm_bench.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

B = 1024
D = 1024
M = 4 * D
group_size = 64
bits = 4
dtype = mx.float16
loops = 10


def qmm_(x, wq1, wq2, q_type):
    for i in range(loops):
        x = mx.quantized_matmul(
            x,
            *wq1,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
        x = mx.quantized_matmul(
            x,
            *wq2,
            group_size=group_size,
            bits=bits,
            quantization_type=q_type,
        )
    return x


def affine_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine")


def affine_packed_qmm(x, wq1, wq2):
    return qmm_(x, wq1, wq2, "affine-packed")


def time_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
    mx.eval(x, wq1, wq2)
    time_fn(affine_qmm, x, wq1, wq2)


def time_packed_qmm():
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(
        w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(
        w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
    )
    mx.eval(x, wq1, wq2)
    time_fn(affine_packed_qmm, x, wq1, wq2)


if __name__ == "__main__":
    for b in [2, 4, 8]:
        bits = b
        print(f"Bits {bits}:")
        time_qmm()
        time_packed_qmm()
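For orientation, here is a minimal sketch of the API surface this benchmark exercises, namely the quantization_type keyword that this change adds to mx.quantize and mx.quantized_matmul. The shapes are illustrative, and the expectation that the packed variant returns no separate biases is inferred from the bindings further down in this diff rather than verified here:

import mlx.core as mx

w = mx.random.normal(shape=(4096, 1024)).astype(mx.float16)
x = mx.random.normal(shape=(1, 1024)).astype(mx.float16)

# Standard affine quantization: separate scales and biases.
wq, scales, biases = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")
y = mx.quantized_matmul(
    x, wq, scales=scales, biases=biases,
    group_size=64, bits=4, quantization_type="affine",
)

# Packed variant: scales and biases live in one array, so the third return
# value is expected to be None and no biases are passed to the matmul.
wq_p, scales_p, biases_p = mx.quantize(
    w, group_size=64, bits=4, quantization_type="affine-packed"
)
y_p = mx.quantized_matmul(
    x, wq_p, scales=scales_p, biases=None,
    group_size=64, bits=4, quantization_type="affine-packed",
)
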
@@ -1248,6 +1248,41 @@ METAL_FUNC void adjust_matrix_offsets(
  y += tid.z * output_stride;
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
  }
  y += tid.z * output_stride;
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
@@ -2149,3 +2184,666 @@ template <typename T, const int group_size, const int bits>
    }
  }
}

template <typename T, typename U, int bits>
inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  vec<U, 4> accum = 0;

  if (bits == 2) {
    for (int i = 0; i < 4; i++) {
      auto ws = as_type<vec<uint8_t, 4>>(w[i]);
      for (int j = 0; j < 4; j++) {
        accum[i] +=
            (x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) +
             x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0));
      }
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < 4; i++) {
      auto ws = as_type<vec<uint16_t, 2>>(w[i]);
      for (int j = 0; j < 2; j++) {
        accum[i] +=
            (x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
             x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
      }
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < 4; i++) {
      auto ws = as_type<vec<uint8_t, 4>>(w[i]);
      for (int j = 0; j < 4; j++) {
        accum[i] += x[j] * ws[j];
      }
    }
  }

  return accum;
}

template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
    const device vec<uint32_t, 4>* w,
    const device vec<T, 4>* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  vec<U, results_per_simdgroup> result = 0;

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size * 2 / group_size;
  const int w_row = tid.x * num_simdgroups + simd_gid;
  const int out_row = w_row * results_per_simdgroup;

  w += w_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += w_row * in_vec_size_g + 2 * (simd_lid / scale_step_per_thread);
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    // Load the input vector
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    // Load the scales and biases
    vec<T, 4> s = scales[0];
    vec<T, 4> b = scales[1];

    // Load the weights and perform the partial dot product
    vec<U, 4> accum = 0;
    for (int pack = 0; pack < packs_per_thread; pack++) {
      accum +=
          partial_qdot_vec<T, U, bits>(x_thread + pack * pack_factor, w[pack]);
    }

    // Finalize the dot product and accumulate it
    for (int i = 0; i < 4; i++) {
      result[i] += static_cast<U>(s[i]) * accum[i] + static_cast<U>(b[i]) * sum;
    }

    w += block_size / pack_factor;
    scales += 2 * block_size / group_size;
    x += block_size;
  }

  result = simd_sum(result);
  if (simd_lid == 0) {
    for (int row = 0; row < results_per_simdgroup; row++) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}

template <typename T, int group_size, int bits, int results_per_simdgroup>
METAL_FUNC void affine_packed_byte_qmv_fast_impl(
    const device uint8_t* w,
    const device vec<T, 2 * results_per_simdgroup>* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = 2;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = (bits == 3) ? 8 : 4;
  ;
  constexpr int bytes_per_pack = 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  vec<U, results_per_simdgroup> result = 0;

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int scales_row = tid.x * num_simdgroups + simd_gid;
  const int out_row = scales_row * results_per_simdgroup;

  w += out_row * in_vec_size_w + simd_lid * (packs_per_thread * bytes_per_pack);
  scales += scales_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    // Load the input vector
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    // Load the scales and biases
    vec<T, 2 * results_per_simdgroup> sb = scales[0];

    // Load the weights and perform the partial dot product
    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] += qdot<U, values_per_thread, bits>(
          w + row * in_vec_size_w, x_thread, sb[row], sb[2 + row], sum);
    }

    w += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}

template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
    const device vec<uint32_t, 4>* w [[buffer(0)]],
    const device vec<T, 4>* scales [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (bits & (bits - 1)) {
    affine_packed_byte_qmv_fast_impl<T, group_size, bits, 2>(
        (const device uint8_t*)w,
        scales,
        x,
        y,
        in_vec_size,
        out_vec_size,
        tid,
        simd_gid,
        simd_lid);
  } else {
    affine_packed_qmv_fast_impl<T, group_size, bits>(
        w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
  }
}

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct AffinePackedQuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  MLX_MTL_CONST short pack_factor = 32 / bits;
  MLX_MTL_CONST short row_pack_factor = 4;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS * row_pack_factor / pack_factor;
  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
  MLX_MTL_CONST short TOTAL_INTS = BCOLS_PACKED * BROWS_PACKED;
  MLX_MTL_CONST short n_reads =
      (TOTAL_INTS < tgp_size) ? 1 : TOTAL_INTS / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  static_assert(
      n_reads <= row_pack_factor,
      "The loader only supports per thread reads <= row_pack_factor");

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;
  const short bii;
  const short bjj;

  const device uint32_t* src;
  const device T* scales;
  const device T* biases;
  threadgroup T* dst;

  AffinePackedQuantizedBlockLoader(
      const device uint32_t* src_,
      const device T* scales_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS_PACKED : BROWS_PACKED * src_ld),
        group_step_cnt(0),
        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        bii(bi * row_pack_factor + bj % row_pack_factor),
        bjj(bj / row_pack_factor),
        src(src_ + bi * src_ld * row_pack_factor / pack_factor + bj),
        scales(
            scales_ + bi * 2 * src_ld * row_pack_factor / group_size +
            bj % row_pack_factor),
        biases(scales + row_pack_factor),
        dst(dst_ + bii * dst_ld + bjj * pack_factor) {}

  void load_unsafe() const {
    if (bits == 2 && BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    for (int i = 0; i < n_reads; i++) {
      T scale = scales[i];
      T bias = biases[i];
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    for (int i = 0; i < n_reads; i++) {
      T scale = scales[i];
      T bias = biases[i];
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + i), scale, bias, dst + i * dst_ld);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales += (2 * row_pack_factor);
          biases += (2 * row_pack_factor);
        }
      } else {
        scales += (2 * row_pack_factor);
        biases += (2 * row_pack_factor);
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct AffineScalesPackedQuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  MLX_MTL_CONST short row_pack_factor = 2;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;
  const short bii;

  const device uint8_t* src;
  const device T* scales;
  const device T* biases;
  threadgroup T* dst;

  AffineScalesPackedQuantizedBlockLoader(
      const device uint32_t* src_,
      const device T* scales_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        bii(bi / row_pack_factor),
        src(((const device uint8_t*)src_) +
            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
        scales(
            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
            bi % row_pack_factor),
        biases(scales + row_pack_factor),
        dst(dst_ + bi * dst_ld + bj * pack_factor) {}

  void load_unsafe() const {
    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + bytes_per_pack * i),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (TOTAL_READS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    for (int i = 0; i < n_reads; i++) {
      T scale = scales[i];
      T bias = biases[i];
      dequantize<T, pack_factor, bits>(
          (const device uint8_t*)(src + bytes_per_pack * i * src_ld),
          scale,
          bias,
          dst + i * dst_ld);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales += (2 * row_pack_factor);
          biases += (2 * row_pack_factor);
        }
      } else {
        scales += (2 * row_pack_factor);
        biases += (2 * row_pack_factor);
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void affine_packed_qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;
  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;
  using loader_w_t = typename ConditionalType<
      power_of_2_bits,
      loader_fully_packed_t,
      loader_scales_packed_t>::type;

  // Set the block
  const int K_w =
      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
  const int K_g = K * 2 * row_pack_factor / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const int packed_y_col = tid.x * (BN / row_pack_factor);

  x += y_row * K;
  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
  scales += packed_y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_packed_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& K [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        tid);
  }
  affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

@@ -60,6 +60,14 @@
      bits, \
      split_k)

#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)

@@ -96,12 +104,20 @@
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)

#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
  instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, true) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, true, false) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, true) \
  instantiate_quantized_aligned_batched(affine_packed_qmm_t, type, group_size, bits, false, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_all_splitk(type, group_size, bits)
  instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_all_affine_packed(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \

@@ -377,10 +377,187 @@ void qmm_op(
      s);
}

void affine_packed_qmv(
    const std::vector<array>& inputs,
    array& out,
    int B,
    int D,
    int O,
    int group_size,
    int bits,
    const Stream& s) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& d = metal::device(s.device);
  auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      d.add_temporary(arr_copy, s.index);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous_last_dims(inputs[0]);
  auto w = ensure_row_contiguous_last_dims(inputs[1]);
  auto scales = ensure_row_contiguous_last_dims(inputs[2]);

  const bool pow2_bits = (bits & (bits - 1)) == 0;
  const int n_simdgroups = 2;
  const int results_per_simdgroup = (pow2_bits) ? 4 : 2;
  MTL::Size group_dims(32, n_simdgroups, 1);
  MTL::Size grid_dims(O / n_simdgroups / results_per_simdgroup, B, 1);

  std::string name;
  name.reserve(64);
  concatenate(
      name,
      (D % 512 == 0) ? "affine_packed_qmv_fast_" : "affine_packed_qmv_",
      get_type_string(out.dtype()),
      "_gs_",
      std::to_string(group_size),
      "_b_",
      std::to_string(bits));
  auto kernel = get_quantized_kernel(d, name, "");
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(x, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(D, 4);
  compute_encoder.set_bytes(O, 5);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void affine_packed_qmm_t(
    const std::vector<array>& inputs,
    array& out,
    bool batched,
    int B,
    int D,
    int O,
    int group_size,
    int bits,
    const Stream& s) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& d = metal::device(s.device);
  auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      d.add_temporary(arr_copy, s.index);
      return arr_copy;
    }
  };
  // TODO: Deal with this in routing towards qmm_n instead of qmm_t
  auto x = ensure_row_contiguous_last_dims(inputs[0]);
  auto w = ensure_row_contiguous_last_dims(inputs[1]);
  auto scales = ensure_row_contiguous_last_dims(inputs[2]);

  int x_batch_ndims = x.ndim() - 2;
  auto& x_shape = x.shape();
  auto& x_strides = x.strides();
  int w_batch_ndims = w.ndim() - 2;
  auto& w_shape = w.shape();
  auto& w_strides = w.strides();
  auto& s_strides = scales.strides();

  const int wn = 2;
  const int wm = 2;
  const int bm = 32;
  const int bn = 32;
  const int N = (batched) ? out.size() / B / O : 1;
  MTL::Size group_dims(32, wn, wm);
  MTL::Size grid_dims((O + bn - 1) / bn, (B + bm - 1) / bm, N);

  std::string name;
  name.reserve(64);
  concatenate(
      name,
      "affine_packed_qmm_t_",
      get_type_string(out.dtype()),
      "_gs_",
      std::to_string(group_size),
      "_b_",
      std::to_string(bits),
      "_alN_",
      ((O % 32) == 0) ? "true" : "false",
      "_batch_",
      (batched) ? "true" : "false");
  auto kernel = get_quantized_kernel(d, name, "");
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(x, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(D, 4);
  compute_encoder.set_bytes(O, 5);
  compute_encoder.set_bytes(B, 6);
  if (batched) {
    compute_encoder.set_bytes(x_batch_ndims, 7);
    compute_encoder.set_vector_bytes(x_shape, 8);
    compute_encoder.set_vector_bytes(x_strides, 9);
    compute_encoder.set_bytes(w_batch_ndims, 10);
    compute_encoder.set_vector_bytes(w_shape, 11);
    compute_encoder.set_vector_bytes(w_strides, 12);
    compute_encoder.set_vector_bytes(s_strides, 13);
  }
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

void affine_packed_qmm_op(
    const std::vector<array>& inputs,
    array& out,
    bool transpose,
    int group_size,
    int bits,
    const Stream& s) {
  auto& x = inputs[0];
  auto& w = inputs[1];
  bool batched = w.ndim() > 2;
  int D = x.shape(-1);
  int O = out.shape(-1);
  int B = (batched) ? x.shape(-2) : x.size() / D;

  if (transpose) {
    if (B < 6) {
      affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
    } else {
      affine_packed_qmm_t(inputs, out, batched, B, D, O, group_size, bits, s);
    }
  } else {
  }
}

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 4);
  qmm_op(
      inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
  if (type_ == QuantizationType::Affine) {
    assert(inputs.size() == 4);
    qmm_op(
        inputs,
        out,
        transpose_,
        group_size_,
        bits_,
        /*gather=*/false,
        stream());
  }

  if (type_ == QuantizationType::AffinePacked) {
    assert(inputs.size() == 3);
    affine_packed_qmm_op(inputs, out, transpose_, group_size_, bits_, stream());
  }
}

void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
mlx/ops.cpp (189 changed lines)
@@ -75,10 +75,35 @@ std::pair<int, int> extract_quantized_matmul_dims(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    bool transpose,
    int group_size,
    int bits) {
    int bits,
    QuantizationType quantization_type) {
  // Check if we have biases as expected
  switch (quantization_type) {
    case QuantizationType::Affine:
      if (!biases.has_value()) {
        std::ostringstream msg;
        msg << "[" << tag
            << "] The biases argument is required for quantization "
            << "type '" << quantization_type << "'";
        throw std::invalid_argument(msg.str());
      }
      break;
    case QuantizationType::AffinePacked:
      if (biases.has_value()) {
        std::ostringstream msg;
        msg << "[" << tag << "] Quantization type '" << quantization_type
            << "' does not use "
            << "biases but biases were provided";
        throw std::invalid_argument(msg.str());
      }
      break;
  }

  bool pow2_bits = (bits & (bits - 1)) == 0;

  if (w.dtype() != uint32) {
    std::ostringstream msg;
    msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -86,11 +111,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
    throw std::invalid_argument(msg.str());
  }

  if (scales.shape() != biases.shape()) {
  if (biases.has_value() && scales.shape() != biases.value().shape()) {
    std::ostringstream msg;
    msg << "[" << tag << "] Scales and biases should have the same shape. "
        << "Received scales with shape " << scales.shape()
        << " and biases with " << biases.shape();
        << " and biases with " << biases.value().shape();
    throw std::invalid_argument(msg.str());
  }

@@ -99,25 +124,42 @@ std::pair<int, int> extract_quantized_matmul_dims(
    std::ostringstream msg;
    msg << "[" << tag
        << "] Weight, scales and biases should have the same batch shape. "
        << "Received weight with shape " << w.shape() << ", scales with "
        << scales.shape() << " and biases with " << biases.shape();
        << "Received weight with shape " << w.shape()
        << " and scales/biases with " << scales.shape();
    throw std::invalid_argument(msg.str());
  }

  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
  int weight_dims = w.shape(-1) * 32 / bits;
  int scales_dims = scales.shape(-1) * group_size;
  if (quantization_type == QuantizationType::AffinePacked) {
    if (pow2_bits) {
      scales_dims /= 8;
      weight_dims /= 4;
    } else {
      scales_dims /= 4;
    }
  }

  if (weight_dims != scales_dims) {
    std::ostringstream msg;
    msg << "[" << tag << "] The shapes of the weight and scales are "
        << "incompatible based on bits and group_size. w.shape() == "
        << w.shape() << " and scales.shape() == " << scales.shape()
        << " with group_size=" << group_size << " and bits=" << bits;
        << "incompatible based on bits, group_size and quantization type. "
        << "w.shape() == " << w.shape()
        << " and scales.shape() == " << scales.shape()
        << " with group_size=" << group_size << ", bits=" << bits
        << " and type='" << quantization_type << "'";
    throw std::invalid_argument(msg.str());
  }

  int x_inner_dims = x.shape(-1);

  // Calculate the expanded w's dims
  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
  int weight_dims_other = w.shape(-2);
  if (quantization_type == QuantizationType::AffinePacked && pow2_bits) {
    weight_dims_other *= 4;
  }
  int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
  int w_outer_dims = (transpose) ? weight_dims_other : weight_dims;

  if (w_inner_dims != x_inner_dims) {
    std::ostringstream msg;
@@ -3662,14 +3704,23 @@ array quantized_matmul(
    array x,
    array w,
    array scales,
    array biases,
    std::optional<array> biases,
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    QuantizationType quantization_type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  // Check and extract the quantized matrix shape against x
  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
      "quantized_matmul",
      x,
      w,
      scales,
      biases,
      transpose,
      group_size,
      bits,
      quantization_type);

  // QuantizedMatmul handles w.ndim == 2 case.
  if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3690,69 +3741,134 @@ array quantized_matmul(
    *(inner_shape.end() - 1) = scales.shape(-1);
    scales = broadcast_to(scales, inner_shape, s);

    *(inner_shape.end() - 1) = biases.shape(-1);
    biases = broadcast_to(biases, inner_shape, s);
    if (biases.has_value()) {
      *(inner_shape.end() - 1) = biases.value().shape(-1);
      biases = broadcast_to(biases.value(), inner_shape, s);
    }
  }

  auto dtype = result_type(x, scales, biases);
  auto dtype = result_type(x, scales);
  if (biases.has_value()) {
    dtype = promote_types(dtype, biases.value().dtype());
  }
  if (!issubdtype(dtype, floating)) {
    std::ostringstream msg;
    msg << "[quantized_matmul] Only real floating types are supported but "
        << "the passed types where x.dtype() == " << x.dtype()
        << ", scales.dtype() == " << scales.dtype()
        << " and biases.dtype() == " << biases.dtype();
        << ", scales.dtype() == " << scales.dtype();
    if (biases.has_value()) {
      msg << " and biases.dtype() == " << biases.value().dtype();
    }
    throw std::invalid_argument(msg.str());
  }

  // Prepare the inputs vector
  std::vector<array> inputs;
  inputs.reserve(4);
  inputs.push_back(astype(x, dtype, s));
  inputs.push_back(w);
  inputs.push_back(astype(scales, dtype, s));
  if (biases.has_value()) {
    inputs.push_back(astype(biases.value(), dtype, s));
  }

  auto out_shape = x.shape();
  out_shape.back() = w_outer_dims;

  return array(
      std::move(out_shape),
      dtype,
      std::make_shared<QuantizedMatmul>(
          to_stream(s), group_size, bits, transpose),
      {astype(x, dtype, s),
       w,
       astype(scales, dtype, s),
       astype(biases, dtype, s)});
          to_stream(s), quantization_type, group_size, bits, transpose),
      std::move(inputs));
}

std::tuple<array, array, array> quantize(
std::tuple<array, array, std::optional<array>> quantize(
    const array& w,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    QuantizationType quantization_type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  return fast::affine_quantize(w, group_size, bits, s);
  switch (quantization_type) {
    case QuantizationType::Affine:
      return fast::affine_quantize(w, group_size, bits, s);
    case QuantizationType::AffinePacked: {
      auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);

      int pow2_bits = (bits & (bits - 1)) == 0;
      int row_packing = (pow2_bits) ? 4 : 2;

      scales = unflatten(scales, -2, {-1, row_packing}, s);
      biases = unflatten(biases, -2, {-1, row_packing}, s);
      scales = concatenate({scales, biases}, -2, s);
      scales = moveaxis(scales, -2, -1, s);
      scales = flatten(scales, -2, -1, s);

      if (pow2_bits) {
        wq = unflatten(wq, -2, {-1, row_packing}, s);
        wq = moveaxis(wq, -2, -1, s);
        wq = flatten(wq, -2, -1, s);
      }

      return std::make_tuple(wq, scales, std::nullopt);
    }
  }
}

array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    QuantizationType quantization_type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
  return fast::affine_dequantize(
      w, scales, biases.value(), group_size, bits, s);
}

array gather_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    std::optional<array> lhs_indices_ /* = std::nullopt */,
    std::optional<array> rhs_indices_ /* = std::nullopt */,
    const std::optional<array>& biases,
    const std::optional<array>& lhs_indices_ /* = std::nullopt */,
    const std::optional<array>& rhs_indices_ /* = std::nullopt */,
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    QuantizationType quantization_type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  if (!lhs_indices_ && !rhs_indices_) {
    return quantized_matmul(
        x, w, scales, biases, transpose, group_size, bits, s);
        x,
        w,
        scales,
        biases,
        transpose,
        group_size,
        bits,
        quantization_type,
        s);
  }

  if (quantization_type != QuantizationType::Affine) {
    std::ostringstream msg;
    msg << "[gather_qmm] Only quantization type '" << quantization_type
        << "' supported";
    throw std::invalid_argument(msg.str());
  }

  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
      "gather_qmm",
      x,
      w,
      scales,
      biases,
      transpose,
      group_size,
      bits,
      quantization_type);

  // Extract indices and broadcast them
  array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3768,16 +3884,17 @@ array gather_qmm(
  out_shape.push_back(w_outer_dims);

  // and output type
  auto out_type = result_type(x, scales, biases);
  auto out_type = result_type(x, scales, biases.value());

  return array(
      std::move(out_shape),
      out_type,
      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
      std::make_shared<GatherQMM>(
          to_stream(s), quantization_type, group_size, bits, transpose),
      {astype(x, out_type, s),
       w,
       astype(scales, out_type, s),
       astype(biases, out_type, s),
       astype(biases.value(), out_type, s),
       lhs_indices,
       rhs_indices});
}

mlx/ops.h (18 changed lines)
@@ -1277,31 +1277,34 @@ array conv_transpose3d(
    int groups = 1,
    StreamOrDevice s = {});

/** Quantized matmul multiplies x with a quantized matrix w*/
/** Quantized matmul multiplies x with a quantized matrix w */
array quantized_matmul(
    array x,
    array w,
    array scales,
    array biases,
    std::optional<array> biases,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    QuantizationType quantization_type = QuantizationType::Affine,
    StreamOrDevice s = {});

/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
std::tuple<array, array, std::optional<array>> quantize(
    const array& w,
    int group_size = 64,
    int bits = 4,
    QuantizationType quantization_type = QuantizationType::Affine,
    StreamOrDevice s = {});

/** Dequantize a matrix produced by quantize() */
array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    int group_size = 64,
    int bits = 4,
    QuantizationType quantization_type = QuantizationType::Affine,
    StreamOrDevice s = {});

/** Compute matrix products with matrix-level gather. */
@@ -1309,12 +1312,13 @@ array gather_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    std::optional<array> lhs_indices = std::nullopt,
    std::optional<array> rhs_indices = std::nullopt,
    const std::optional<array>& biases,
    const std::optional<array>& lhs_indices = std::nullopt,
    const std::optional<array>& rhs_indices = std::nullopt,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    QuantizationType quantization_type = QuantizationType::Affine,
    StreamOrDevice s = {});

/** Returns a contraction of a and b over multiple dimensions. */

@@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
          cotangents[0],
          primals[1],
          primals[2],
          primals[3],
          (primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
          !transpose_,
          group_size_,
          bits_,
          type_,
          stream()));
    }

@@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
              !transpose_,
              group_size_,
              bits_,
              type_,
              stream()),
          -3,
          stream()),

@@ -8,6 +8,7 @@
#include "mlx/device.h"
#include "mlx/io/load.h"
#include "mlx/stream.h"
#include "mlx/utils.h"

#define DEFINE_VMAP() \
  virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
@@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
 public:
  explicit QuantizedMatmul(
      Stream stream,
      QuantizationType type,
      int group_size,
      int bits,
      bool transpose)
      : UnaryPrimitive(stream),
        type_(type),
        group_size_(group_size),
        bits_(bits),
        transpose_(transpose) {}
@@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;

 private:
  QuantizationType type_;
  int group_size_;
  int bits_;
  bool transpose_;
@@ -1595,8 +1599,14 @@ class QuantizedMatmul : public UnaryPrimitive {

class GatherQMM : public UnaryPrimitive {
 public:
  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
  explicit GatherQMM(
      Stream stream,
      QuantizationType type,
      int group_size,
      int bits,
      bool transpose)
      : UnaryPrimitive(stream),
        type_(type),
        group_size_(group_size),
        bits_(bits),
        transpose_(transpose) {}
@@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
  bool is_equivalent(const Primitive& other) const override;

 private:
  QuantizationType type_;
  int group_size_;
  int bits_;
  bool transpose_;

@@ -145,6 +145,30 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
  return os;
}

std::ostream& operator<<(std::ostream& os, QuantizationType type) {
  std::string_view quantization_type;
  switch (type) {
    case QuantizationType::Affine:
      quantization_type = "affine";
      break;
    case QuantizationType::AffinePacked:
      quantization_type = "affine-packed";
      break;
  }
  return os << quantization_type;
}

QuantizationType from_string(const std::string& s) {
  if (s == "affine") {
    return QuantizationType::Affine;
  }
  if (s == "affine-packed") {
    return QuantizationType::AffinePacked;
  }

  throw std::invalid_argument("Cannot map '" + s + "' to a quantization type");
}

namespace {

inline size_t

mlx/utils.h (13 changed lines)
@@ -100,6 +100,19 @@ inline int next_power_of_2(int n) {
  return pow(2, std::ceil(std::log2(n)));
}

/** Enumerate the different quantization types */
enum class QuantizationType {
  // Traditional affine quantization with separate scales and biases
  Affine,

  // The same quantization as affine but with the scales and biases packed in a
  // single array and contiguously every 4 rows
  AffinePacked,
};
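To make the "packed every 4 rows" layout more concrete, below is a rough Python sketch of the interleaving performed for the packed type in the quantize() change in mlx/ops.cpp. The real packing happens in C++; the helper name, the use of reshape in place of unflatten, and the row_packing=4 default (which, per that change, only applies to power-of-two bit widths) are all illustrative:

import mlx.core as mx

def pack_scales_biases(scales, biases, row_packing=4):
    # scales, biases: (..., n_rows, n_groups) from plain affine quantization.
    # Group rows in blocks of `row_packing`, stack each block's scales on top
    # of its biases, then interleave per group so that every group's
    # `row_packing` scales and `row_packing` biases sit next to each other.
    s = scales.reshape(*scales.shape[:-2], -1, row_packing, scales.shape[-1])
    b = biases.reshape(*biases.shape[:-2], -1, row_packing, biases.shape[-1])
    packed = mx.concatenate([s, b], axis=-2)       # (..., n_rows / 4, 8, n_groups)
    packed = mx.moveaxis(packed, -2, -1)           # (..., n_rows / 4, n_groups, 8)
    return packed.reshape(*packed.shape[:-2], -1)  # (..., n_rows / 4, 8 * n_groups)
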
std::ostream& operator<<(std::ostream& os, QuantizationType type);
QuantizationType from_string(const std::string& s);

namespace env {

int get_var(const char* name, int default_value);

@@ -39,6 +39,10 @@ class Embedding(Module):
        """
        return x @ self.weight.T

    def to_quantized(self, group_size: int = 64, bits: int = 4):
    def to_quantized(
        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
    ):
        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
        return QuantizedEmbedding.from_embedding(self, group_size, bits)
        return QuantizedEmbedding.from_embedding(
            self, group_size, bits, quantization_type
        )

@@ -70,9 +70,11 @@ class Linear(Module):
            x = x @ self["weight"].T
        return x

    def to_quantized(self, group_size: int = 64, bits: int = 4):
    def to_quantized(
        self, group_size: int = 64, bits: int = 4, quantization_type: str = "affine"
    ):
        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
        return QuantizedLinear.from_linear(self, group_size, bits)
        return QuantizedLinear.from_linear(self, group_size, bits, quantization_type)


class Bilinear(Module):

@@ -12,6 +12,7 @@ def quantize(
    model: Module,
    group_size: int = 64,
    bits: int = 4,
    quantization_type: str = "affine",
    class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
):
    """Quantize the sub-modules of a module according to a predicate.
@@ -39,7 +40,11 @@ def quantize(
        if bool_or_params := class_predicate(path, m):
            if hasattr(m, "to_quantized"):
                if isinstance(bool_or_params, bool):
                    return m.to_quantized(group_size=group_size, bits=bits)
                    return m.to_quantized(
                        group_size=group_size,
                        bits=bits,
                        quantization_type=quantization_type,
                    )
                elif isinstance(bool_or_params, dict):
                    return m.to_quantized(**bool_or_params)
                else:
@@ -131,9 +136,15 @@ class QuantizedEmbedding(Module):

    @classmethod
    def from_embedding(
        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
        cls,
        embedding_layer: Module,
        group_size: int = 64,
        bits: int = 4,
        quantization_type: str = "affine",
    ):
        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
        if quantization_type != "affine":
            raise ValueError(f"Quantization type {quantization_type} not supported")
        embedding_dims, dims = embedding_layer.weight.shape
        ql = cls(embedding_dims, dims, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
@@ -170,12 +181,14 @@ class QuantizedLinear(Module):
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
        quantization_type: str = "affine",
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits
        self.quantization_type = quantization_type

        # Initialize the quantized weight
        scale = math.sqrt(1 / input_dims)
@@ -184,7 +197,9 @@ class QuantizedLinear(Module):
            high=scale,
            shape=(output_dims, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
        self.weight, self.scales, self.biases = mx.quantize(
            weight, group_size, bits, quantization_type=quantization_type
        )

        # And bias if needed
        if bias:
@@ -212,22 +227,29 @@ class QuantizedLinear(Module):
            x,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            biases=self.get("biases", None),
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
            quantization_type=self.quantization_type,
        )
        if "bias" in self:
            x = x + self["bias"]
        return x

    @classmethod
    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
    def from_linear(
        cls,
        linear_layer: Module,
        group_size: int = 64,
        bits: int = 4,
        quantization_type: str = "affine",
    ):
        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql = cls(input_dims, output_dims, False, group_size, bits, quantization_type)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, group_size, bits
            linear_layer.weight, group_size, bits, quantization_type=quantization_type
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias
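As a usage sketch of the module-level API these changes thread the new keyword through (the model here is a throwaway example, not from the diff):

import mlx.core as mx
import mlx.nn as nn

# Any Module containing Linear layers works; Sequential is just for brevity.
model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))

# Swap the supporting sub-modules for their quantized counterparts using the
# packed affine format added in this change.
nn.quantize(model, group_size=64, bits=4, quantization_type="affine-packed")

y = model(mx.random.normal(shape=(1, 512)))
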
@@ -4018,18 +4018,38 @@ void init_ops(nb::module_& m) {
      )pbdoc");
  m.def(
      "quantized_matmul",
      &mx::quantized_matmul,
      [](mx::array x,
         mx::array w,
         mx::array scales,
         std::optional<mx::array> biases,
         bool transpose,
         int group_size,
         int bits,
         const std::string& quantization_type,
         mx::StreamOrDevice s) {
        return mx::quantized_matmul(
            std::move(x),
            std::move(w),
            std::move(scales),
            std::move(biases),
            transpose,
            group_size,
            bits,
            mx::from_string(quantization_type),
            s);
      },
      nb::arg(),
      nb::arg(),
      "scales"_a,
      "biases"_a,
      "biases"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
      "quantization_type"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array], transpose: bool = True, group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``group_size`` of
@@ -4040,7 +4060,8 @@ void init_ops(nb::module_& m) {
            x (array): Input array
            w (array): Quantized matrix packed in unsigned integers
            scales (array): The scales to use per ``group_size`` elements of ``w``
            biases (array): The biases to use per ``group_size`` elements of ``w``
            biases (array, optional): The biases to use per ``group_size``
                elements of ``w`` depending on the quantization type
            transpose (bool, optional): Defines whether to multiply with the
                transposed ``w`` or not, namely whether we are performing
                ``x @ w.T`` or ``x @ w``. Default: ``True``.
@@ -4048,20 +4069,30 @@ void init_ops(nb::module_& m) {
                shares a scale and bias. Default: ``64``.
            bits (int, optional): The number of bits occupied by each element in
                ``w``. Default: ``4``.
            quantization_type (str, optional): The type of quantization used for the matrix.
                It can be 'affine' or 'affine-packed'.

        Returns:
            array: The result of the multiplication of ``x`` with ``w``.
      )pbdoc");
|
||||
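
A hedged usage sketch of the binding above: quantize a random matrix and multiply against it, unpacking the `(w_q, scales, biases)` tuple positionally in the same style as the benchmark script at the top of this diff. The shapes are arbitrary.

```python
import mlx.core as mx

x = mx.random.normal(shape=(2, 512)).astype(mx.float16)
w = mx.random.normal(shape=(1024, 512)).astype(mx.float16)

# (w_q, scales, biases); for "affine-packed" the biases entry may be None.
wq = mx.quantize(w, group_size=64, bits=4, quantization_type="affine")

y = mx.quantized_matmul(
    x, *wq, transpose=True, group_size=64, bits=4, quantization_type="affine"
)
print(y.shape)  # (2, 1024)
```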
  m.def(
      "quantize",
      &mx::quantize,
      [](const mx::array& w,
         int group_size,
         int bits,
         const std::string& quantization_type,
         mx::StreamOrDevice s) {
        return mx::quantize(
            w, group_size, bits, mx::from_string(quantization_type), s);
      },
      nb::arg(),
      "group_size"_a = 64,
      "bits"_a = 4,
      "quantization_type"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, Optional[array]]"),
      R"pbdoc(
        Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -4103,26 +4134,46 @@ void init_ops(nb::module_& m) {
              scale and bias. Default: ``64``.
            bits (int, optional): The number of bits occupied by each element of
              ``w`` in the returned quantized matrix. Default: ``4``.
            quantization_type (str, optional): The type of quantization used for the matrix.
              It can be 'affine' or 'affine-packed'.

        Returns:
            tuple: A tuple containing

            * w_q (array): The quantized version of ``w``
            * scales (array): The scale to multiply each element with, namely :math:`s`
            * biases (array): The biases to add to each element, namely :math:`\beta`
            * biases (array, optional): The biases to add to each element, namely
              :math:`\beta`. Depending on the quantization type this return value
              may be None.
      )pbdoc");
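
For reference, a hedged sketch of calling the binding above from Python. The packed layout's internals are not documented here, so the packed result is not unpacked.

```python
import mlx.core as mx

w = mx.random.normal(shape=(1024, 1024)).astype(mx.float16)

# Default affine quantization returns per-group scales and biases.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

# The packed variant added in this change may return None as its third
# element, matching the Optional[array] in the new signature, so avoid
# unpacking it blindly.
packed = mx.quantize(w, group_size=64, bits=4, quantization_type="affine-packed")
```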
  m.def(
      "dequantize",
      &mx::dequantize,
      [](const mx::array& wq,
         const mx::array& scales,
         const std::optional<mx::array>& biases,
         int group_size,
         int bits,
         const std::string& quantization_type,
         mx::StreamOrDevice s) {
        return mx::dequantize(
            wq,
            scales,
            biases,
            group_size,
            bits,
            mx::from_string(quantization_type),
            s);
      },
      nb::arg(),
      "scales"_a,
      "biases"_a,
      "group_size"_a = 64,
      "bits"_a = 4,
      "quantization_type"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
          "def dequantize(w: array, /, scales: array, biases: Optional[array], group_size: int = 64, bits: int = 4, quantization_type: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Dequantize the matrix ``w`` using the provided ``scales`` and
        ``biases`` and the ``group_size`` and ``bits`` configuration.
@@ -4143,6 +4194,8 @@ void init_ops(nb::module_& m) {
              scale and bias. Default: ``64``.
            bits (int, optional): The number of bits occupied by each element in
              ``w``. Default: ``4``.
            quantization_type (str, optional): The type of quantization used for the matrix.
              It can be 'affine' or 'affine-packed'.

        Returns:
            array: The dequantized version of ``w``
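
A hedged quantize/dequantize round-trip sketch for the binding above; the reconstruction error depends on ``group_size`` and ``bits``.

```python
import mlx.core as mx

w = mx.random.normal(shape=(512, 512)).astype(mx.float16)
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

# Reconstruct an approximation of the original matrix.
w_hat = mx.dequantize(
    w_q,
    scales=scales,
    biases=biases,
    group_size=64,
    bits=4,
    quantization_type="affine",
)
print(mx.abs(w - w_hat).max())  # small, bounded quantization error
```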
@@ -4153,16 +4206,17 @@ void init_ops(nb::module_& m) {
      nb::arg(),
      nb::arg(),
      "scales"_a,
      "biases"_a,
      "biases"_a = nb::none(),
      "lhs_indices"_a = nb::none(),
      "rhs_indices"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
      "quantization_type"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array], lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform quantized matrix multiplication with matrix-level gather.

@@ -4188,6 +4242,8 @@ void init_ops(nb::module_& m) {
              shares a scale and bias. Default: ``64``.
            bits (int, optional): The number of bits occupied by each element in
              ``w``. Default: ``4``.
            quantization_type (str, optional): The type of quantization used for the matrix.
              It can be 'affine' or 'affine-packed'.

        Returns:
            array: The result of the multiplication of ``x`` with ``w``
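
Finally, a hedged sketch of the ``gather_qmm`` binding; the expert-routing shapes and indices below are illustrative only.

```python
import mlx.core as mx

# Four quantized "expert" matrices, one gathered per input row.
x = mx.random.normal(shape=(8, 1, 256)).astype(mx.float16)
w = mx.random.normal(shape=(4, 512, 256)).astype(mx.float16)
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

rhs_indices = mx.array([0, 1, 2, 3, 0, 1, 2, 3])  # expert id per row of x
y = mx.gather_qmm(
    x,
    w_q,
    scales=scales,
    biases=biases,
    rhs_indices=rhs_indices,
    transpose=True,
    group_size=64,
    bits=4,
)
print(y.shape)  # (8, 1, 512)
```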